diff --git a/.devops/cann.Dockerfile b/.devops/cann.Dockerfile index 9df86d0489..dc95e3f38d 100644 --- a/.devops/cann.Dockerfile +++ b/.devops/cann.Dockerfile @@ -13,6 +13,20 @@ ARG APP_REVISION=N/A # BUILD STAGE # Compile all binary files and libraries # ============================================================================== +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM ${CANN_BASE_IMAGE} AS build # -- Install build dependencies -- @@ -26,6 +40,8 @@ WORKDIR /app # -- Copy project files -- COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + # -- Set CANN environment variables (required for compilation) -- # Using ENV instead of `source` allows environment variables to persist across the entire image layer ENV ASCEND_TOOLKIT_HOME=/usr/local/Ascend/ascend-toolkit/latest diff --git a/.devops/cpu.Dockerfile b/.devops/cpu.Dockerfile index 9dbfdd11df..caf727bcdb 100644 --- a/.devops/cpu.Dockerfile +++ b/.devops/cpu.Dockerfile @@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM docker.io/ubuntu:$UBUNTU_VERSION AS build ARG TARGETARCH @@ -16,6 +30,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN if [ "$TARGETARCH" = "amd64" ] || [ "$TARGETARCH" = "arm64" ]; then \ cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON; \ else \ diff --git a/.devops/cuda.Dockerfile b/.devops/cuda.Dockerfile index 276f82c34c..b16b9a8f1a 100644 --- a/.devops/cuda.Dockerfile +++ b/.devops/cuda.Dockerfile @@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM ${BASE_CUDA_DEV_CONTAINER} AS build ARG GCC_VERSION @@ -26,6 +40,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN if [ "${CUDA_DOCKER_ARCH}" != "default" ]; then \ export CMAKE_ARGS="-DCMAKE_CUDA_ARCHITECTURES=${CUDA_DOCKER_ARCH}"; \ fi && \ diff --git a/.devops/intel.Dockerfile b/.devops/intel.Dockerfile index 4d0c0a8fd8..3c059eb301 100644 --- a/.devops/intel.Dockerfile +++ b/.devops/intel.Dockerfile @@ -5,6 +5,20 @@ ARG APP_REVISION=N/A ## Build Image +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM docker.io/intel/deep-learning-essentials:$ONEAPI_VERSION AS build ARG GGML_SYCL_F16=ON @@ -22,6 +36,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN if [ "${GGML_SYCL_F16}" = "ON" ]; then \ echo "GGML_SYCL_F16 is set" \ && export OPT_SYCL_F16="-DGGML_SYCL_F16=ON" \ diff --git a/.devops/musa.Dockerfile b/.devops/musa.Dockerfile index c98c44c951..0c23cc5547 100644 --- a/.devops/musa.Dockerfile +++ b/.devops/musa.Dockerfile @@ -10,6 +10,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM ${BASE_MUSA_DEV_CONTAINER} AS build # MUSA architecture to build for (defaults to all supported archs) @@ -29,6 +43,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN if [ "${MUSA_DOCKER_ARCH}" != "default" ]; then \ export CMAKE_ARGS="-DMUSA_ARCHITECTURES=${MUSA_DOCKER_ARCH}"; \ fi && \ diff --git a/.devops/openvino.Dockerfile b/.devops/openvino.Dockerfile index 9e96244ced..fec72b1c7d 100644 --- a/.devops/openvino.Dockerfile +++ b/.devops/openvino.Dockerfile @@ -22,6 +22,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + ## Build Image FROM docker.io/ubuntu:${UBUNTU_VERSION} AS build @@ -69,6 +83,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + # Build Stage RUN bash -c "source ${OpenVINO_DIR}/setupvars.sh && \ cmake -B build/ReleaseOV -G Ninja \ diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile index 2ab10f4117..7fad0c22e5 100644 --- a/.devops/rocm.Dockerfile +++ b/.devops/rocm.Dockerfile @@ -11,6 +11,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + ### Build image FROM ${BASE_ROCM_DEV_CONTAINER} AS build @@ -38,6 +52,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN HIPCXX="$(hipconfig -l)/clang" HIP_PATH="$(hipconfig -R)" \ cmake -S . -B build \ -DGGML_HIP=ON \ diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index 05df94ec44..26c1902b14 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM docker.io/ubuntu:$UBUNTU_VERSION AS build # Install build tools @@ -17,6 +31,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN cmake -B build -DGGML_NATIVE=OFF -DGGML_VULKAN=ON -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON && \ cmake --build build --config Release -j$(nproc) diff --git a/.devops/zendnn.Dockerfile b/.devops/zendnn.Dockerfile index 9f811ab278..80daf56710 100644 --- a/.devops/zendnn.Dockerfile +++ b/.devops/zendnn.Dockerfile @@ -3,6 +3,20 @@ ARG BUILD_DATE=N/A ARG APP_VERSION=N/A ARG APP_REVISION=N/A +ARG NODE_VERSION=24 + +FROM docker.io/node:$NODE_VERSION AS web + +ARG APP_VERSION + +WORKDIR /app/tools/ui + +COPY tools/ui/package.json tools/ui/package-lock.json ./ +RUN npm ci + +COPY tools/ui/ ./ +RUN LLAMA_BUILD_NUMBER="$APP_VERSION" npm run build + FROM docker.io/ubuntu:$UBUNTU_VERSION AS build RUN apt-get update && \ @@ -14,6 +28,8 @@ WORKDIR /app COPY . . +COPY --from=web /app/tools/ui/dist tools/ui/dist + RUN cmake -S . -B build -DCMAKE_BUILD_TYPE=Release -DGGML_NATIVE=OFF -DLLAMA_BUILD_TESTS=OFF -DGGML_BACKEND_DL=ON -DGGML_CPU_ALL_VARIANTS=ON -DGGML_ZENDNN=ON && \ cmake --build build -j $(nproc) diff --git a/.dockerignore b/.dockerignore index 064b7c7be8..0b81e83bf5 100644 --- a/.dockerignore +++ b/.dockerignore @@ -10,6 +10,8 @@ build*/ +tools/ui/node_modules/ + models/* /llama-cli diff --git a/.github/labeler.yml b/.github/labeler.yml index 3361118ed9..bf994928f9 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -35,8 +35,20 @@ AMD ZenDNN: documentation: - changed-files: - any-glob-to-any-file: + - **/*.md - docs/** - media/** +examples: + - all: + - changed-files: + - any-glob-to-any-file: + - app/** + - examples/** + - tools/** + - all-globs-to-all-files: + - '!tools/server/**' + - '!tools/mtmd/**' + - '!tools/ui/**' testing: - changed-files: - any-glob-to-any-file: @@ -47,28 +59,12 @@ build: - cmake/** - CMakeLists.txt - CMakePresets.json -examples: - - changed-files: - - any-glob-to-any-file: - - examples/** - - tools/** devops: - changed-files: - any-glob-to-any-file: - .devops/** - .github/** - ci/** -python: - - changed-files: - - any-glob-to-any-file: - - "**/*.py" - - requirements/** - - gguf-py/** - - .flake8 -script: - - changed-files: - - any-glob-to-any-file: - - scripts/** android: - changed-files: - any-glob-to-any-file: @@ -81,9 +77,20 @@ server: - changed-files: - any-glob-to-any-file: - tools/server/** - - - +mtmd: + - changed-files: + - any-glob-to-any-file: + - tools/mtmd/** +conversion: + - changed-files: + - any-glob-to-any-file: + - conversion/** + - convert_*.py + - gguf-py/** +vendor: + - changed-files: + - any-glob-to-any-file: + - vendor/** ggml: - changed-files: - any-glob-to-any-file: diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 8195a55ff2..afe4b7c664 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -58,6 +58,13 @@ jobs: git tag ${{ steps.srctag.outputs.name }} || exit 0 git push origin ${{ steps.srctag.outputs.name }} || exit 0 + build_ui: + name: Build UI + needs: create_tag + uses: ./.github/workflows/ui-build.yml + with: + hf_ui_version: ${{ needs.create_tag.outputs.source_tag }} + prepare_matrices: name: Prepare Docker matrices runs-on: ubuntu-24.04 @@ -79,7 +86,7 @@ jobs: [ { "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04" }, { "tag": "cpu", "dockerfile": ".devops/cpu.Dockerfile", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-arm" }, - { "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x" }, + { "tag": "cpu", "dockerfile": ".devops/s390x.Dockerfile", "platforms": "linux/s390x", "full": true, "light": true, "server": true, "free_disk_space": false, "runs_on": "ubuntu-24.04-s390x", "prebuilt_ui": true }, { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" }, { "tag": "cuda cuda12", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "12.8.1", "platforms": "linux/arm64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04-arm" }, { "tag": "cuda13", "dockerfile": ".devops/cuda.Dockerfile", "cuda_version": "13.3.0", "platforms": "linux/amd64", "full": true, "light": true, "server": true, "free_disk_space": true, "runs_on": "ubuntu-24.04" }, @@ -135,7 +142,7 @@ jobs: push_to_registry: name: Push Docker image to Docker Registry - needs: [prepare_matrices, create_tag] + needs: [prepare_matrices, create_tag, build_ui] runs-on: ${{ matrix.config.runs_on }} strategy: @@ -150,6 +157,13 @@ jobs: fetch-depth: 0 ref: ${{ needs.create_tag.outputs.source_tag }} + - name: Download prebuilt UI + if: ${{ matrix.config.prebuilt_ui == true }} + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8 + with: + name: ui-build + path: tools/ui/dist + - name: Set up QEMU if: ${{ contains(matrix.config.platforms, 'linux/amd64') }} uses: docker/setup-qemu-action@ce360397dd3f832beb865e1373c09c0e9f86d70a # v4 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9789215f29..c7b67e4925 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1627,6 +1627,7 @@ jobs: **Windows:** - [Windows x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-x64.zip) - [Windows arm64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cpu-arm64.zip) + - [Windows arm64 (OpenCL Adreno)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-opencl-adreno-arm64.zip) - [Windows x64 (CUDA 12)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-12.4-x64.zip) - [CUDA 12.4 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-12.4-x64.zip) - [Windows x64 (CUDA 13)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-cuda-13.3-x64.zip) - [CUDA 13.3 DLLs](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/cudart-llama-bin-win-cuda-13.3-x64.zip) - [Windows x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-win-vulkan-x64.zip) diff --git a/.pi/gg/SYSTEM.md b/.pi/gg/SYSTEM.md index 197173faed..17ce71cc1b 100644 --- a/.pi/gg/SYSTEM.md +++ b/.pi/gg/SYSTEM.md @@ -25,13 +25,3 @@ Commits: - Do not explicitly set the git author in commits - rely on the default git config - Always use `--no-gpg-sign` when committing - Never `git push` without explicit confirmation from the user - -Resources (read on demand): -- [CONTRIBUTING.md](CONTRIBUTING.md) -- [Build documentation](docs/build.md) -- [Server usage documentation](tools/server/README.md) -- [Server development documentation](tools/server/README-dev.md) -- [PEG parser](docs/development/parsing.md) -- [Auto parser](docs/autoparser.md) -- [Jinja engine](common/jinja/README.md) -- [PR template](.github/pull_request_template.md) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9e7b1253c7..81f23d7e70 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -222,6 +222,16 @@ if (LLAMA_BUILD_APP) add_subdirectory(app) endif() +# Standalone libmtmd build without pulling in the rest of the tools/ tree. +# Useful when packaging just the mtmd library for language bindings (e.g. an +# Apple XCFramework, or a WASM build). When the full tools build is enabled, +# mtmd is already built by the tools/ subdirectory above; this hook only fires +# when LLAMA_BUILD_TOOLS is OFF to avoid double-adding the target. +option(LLAMA_BUILD_MTMD "llama: build tools/mtmd library standalone" OFF) +if (LLAMA_BUILD_MTMD AND NOT (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TOOLS)) + add_subdirectory(tools/mtmd) +endif() + # # install # diff --git a/CODEOWNERS b/CODEOWNERS index 4b9d901771..46fd518b7e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -10,7 +10,7 @@ # ggml-org/ggml-rpc : rgerganov # ggml-org/ggml-sycl : arthw # ggml-org/ggml-vulkan : 0cc4m, jeffbolznv -# ggml-org/ggml-webgpu : reeselevine +# ggml-org/ggml-webgpu : reeselevine, yomaytk # ggml-org/ggml-zdnn : taronaeo # ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin # ggml-org/llama-mtmd : ngxson diff --git a/README.md b/README.md index 0652d13f29..e98f2b7f18 100644 --- a/README.md +++ b/README.md @@ -142,7 +142,9 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo - [x] [GigaChat-20B-A3B](https://huggingface.co/ai-sage/GigaChat-20B-A3B-instruct) - [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview) - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32) -- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) +- [x] [Liquid LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2) +- [x] [Liquid LFM2.5 models](https://huggingface.co/collections/LiquidAI/lfm25) +- [x] [Liquid Nanos](https://huggingface.co/collections/LiquidAI/liquid-nanos) - [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7) - [x] [BailingMoeV2 (Ring/Ling 2.0) models](https://huggingface.co/collections/inclusionAI/ling-v2-68bf1dd2fc34c306c1fa6f86) - [x] [Mellum models](https://huggingface.co/JetBrains/models?search=mellum) diff --git a/app/CMakeLists.txt b/app/CMakeLists.txt index 3ce503955b..3450ff4900 100644 --- a/app/CMakeLists.txt +++ b/app/CMakeLists.txt @@ -1,6 +1,6 @@ set(TARGET llama-app) -add_executable(${TARGET} llama.cpp) +add_executable(${TARGET} llama.cpp download.cpp) set_target_properties(${TARGET} PROPERTIES OUTPUT_NAME llama) target_link_libraries(${TARGET} PRIVATE diff --git a/app/download.cpp b/app/download.cpp new file mode 100644 index 0000000000..7227baadcb --- /dev/null +++ b/app/download.cpp @@ -0,0 +1,71 @@ +#include "arg.h" +#include "common.h" +#include "download.h" +#include "log.h" + +#include +#include + +static void print_usage(int /*argc*/, char ** argv) { + printf( + "\nexamples:\n" + " %s -hf ggml-org/gemma-3-4b-it-qat-GGUF\n" + " %s -hf ggml-org/gemma-3-4b-it-qat-GGUF:Q4_K_M\n" + " %s -hf ggml-org/models -hff model.gguf\n" + " %s -mu https://example.com/model.gguf -m model.gguf\n" + "\n", + argv[0], argv[0], argv[0], argv[0] + ); +} + +int llama_download(int argc, char ** argv); + +int llama_download(int argc, char ** argv) { + common_init(); + + common_params params; + params.verbosity = LOG_LEVEL_ERROR; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DOWNLOAD, print_usage)) { + return 1; + } + + const bool has_source = !params.model.hf_repo.empty() || !params.model.url.empty() || + !params.model.path.empty() || !params.model.docker_repo.empty(); + if (!has_source) { + fprintf(stderr, "error: no model source specified (use --hf-repo, --model-url, --model or --docker-repo)\n"); + return 1; + } + + try { + common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_DOWNLOAD); + common_models_handler_apply(handler, params); + } catch (const std::exception & e) { + fprintf(stderr, "error: %s\n", e.what()); + return 1; + } + + if (!params.models_preset.empty()) { + // -hf pointed at a preset repo: print the preset path and stop + printf("%s\n", params.models_preset.c_str()); + return 0; + } + if (params.model.path.empty()) { + fprintf(stderr, "error: model download failed\n"); + return 1; + } + if (!std::filesystem::exists(params.model.path)) { + fprintf(stderr, "error: model file does not exist: %s\n", params.model.path.c_str()); + return 1; + } + + printf("%s\n", params.model.path.c_str()); + if (!params.mmproj.path.empty()) { + printf("%s\n", params.mmproj.path.c_str()); + } + if (!params.speculative.draft.mparams.path.empty()) { + printf("%s\n", params.speculative.draft.mparams.path.c_str()); + } + + return 0; +} diff --git a/app/llama.cpp b/app/llama.cpp index c4578ea53b..00babbc7b4 100644 --- a/app/llama.cpp +++ b/app/llama.cpp @@ -19,6 +19,7 @@ int llama_batched_bench(int argc, char ** argv); int llama_fit_params(int argc, char ** argv); int llama_quantize(int argc, char ** argv); int llama_perplexity(int argc, char ** argv); +int llama_download(int argc, char ** argv); // Self-update is only supported for binaries built with llama-install.sh static int llama_update(int argc, char ** argv) { @@ -61,6 +62,7 @@ static const command cmds[] = { {"serve", "HTTP API server", {"server"}, false, llama_server }, {"cli", "Command-line interactive interface", {"client"}, false, llama_cli }, {"update", "Update llama to the latest release", {}, UPDATE_HIDDEN, llama_update }, + {"download", "Download a model", {"get"}, false, llama_download }, {"completion", "Text completion", {"complete"}, true, llama_completion }, {"bench", "Benchmark prompt processing and text generation", {}, true, llama_bench }, {"batched-bench", "Benchmark batched decoding performance", {}, true, llama_batched_bench}, diff --git a/build-xcframework.sh b/build-xcframework.sh index 180c01a88e..3a265e53ee 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -13,6 +13,7 @@ LLAMA_BUILD_EXAMPLES=OFF LLAMA_BUILD_TOOLS=OFF LLAMA_BUILD_TESTS=OFF LLAMA_BUILD_SERVER=OFF +LLAMA_BUILD_MTMD=ON GGML_METAL=ON GGML_METAL_EMBED_LIBRARY=ON GGML_BLAS_DEFAULT=ON @@ -39,6 +40,7 @@ COMMON_CMAKE_ARGS=( -DLLAMA_BUILD_TOOLS=${LLAMA_BUILD_TOOLS} -DLLAMA_BUILD_TESTS=${LLAMA_BUILD_TESTS} -DLLAMA_BUILD_SERVER=${LLAMA_BUILD_SERVER} + -DLLAMA_BUILD_MTMD=${LLAMA_BUILD_MTMD} -DGGML_METAL_EMBED_LIBRARY=${GGML_METAL_EMBED_LIBRARY} -DGGML_BLAS_DEFAULT=${GGML_BLAS_DEFAULT} -DGGML_METAL=${GGML_METAL} @@ -126,6 +128,8 @@ setup_framework_structure() { cp ggml/include/ggml-cpu.h ${header_path} cp ggml/include/ggml-blas.h ${header_path} cp ggml/include/gguf.h ${header_path} + cp tools/mtmd/mtmd.h ${header_path} + cp tools/mtmd/mtmd-helper.h ${header_path} # Create module map (common for all platforms) cat > ${module_path}module.modulemap << EOF @@ -247,6 +251,7 @@ combine_static_libraries() { "${base_dir}/${build_dir}/ggml/src/${release_dir}/libggml-cpu.a" "${base_dir}/${build_dir}/ggml/src/ggml-metal/${release_dir}/libggml-metal.a" "${base_dir}/${build_dir}/ggml/src/ggml-blas/${release_dir}/libggml-blas.a" + "${base_dir}/${build_dir}/tools/mtmd/${release_dir}/libmtmd.a" ) # Create temporary directory for processing diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index c42320c46b..fc16b21cf1 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -80,8 +80,6 @@ add_library(${TARGET} http.h imatrix-loader.cpp imatrix-loader.h - json-partial.cpp - json-partial.h json-schema-to-grammar.cpp llguidance.cpp log.cpp diff --git a/common/arg.cpp b/common/arg.cpp index bd4b113d16..841a38e961 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -17,6 +17,7 @@ # define NOMINMAX #endif #include +#include #endif #define JSON_ASSERT GGML_ASSERT @@ -296,60 +297,6 @@ struct handle_model_result { std::string preset_path; }; -static handle_model_result common_params_handle_model(struct common_params_model & model, - const common_download_opts & opts) { - handle_model_result result; - - if (!model.docker_repo.empty()) { - model.path = common_docker_resolve_model(model.docker_repo); - model.name = model.docker_repo; - } else if (!model.hf_repo.empty()) { - // If -m was used with -hf, treat the model "path" as the hf_file to download - if (model.hf_file.empty() && !model.path.empty()) { - model.hf_file = model.path; - model.path = ""; - } - common_download_opts hf_opts = opts; - auto download_result = common_download_model(model, hf_opts); - - if (!download_result.preset_path.empty()) { - result.found_preset = true; - result.preset_path = download_result.preset_path; - return result; // skip everything else if preset.ini is used - } - - if (download_result.model_path.empty()) { - throw std::runtime_error("failed to download model from Hugging Face"); - } - - model.name = model.hf_repo; - model.path = download_result.model_path; - - if (!download_result.mmproj_path.empty()) { - result.found_mmproj = true; - result.mmproj.path = download_result.mmproj_path; - } - - if (!download_result.mtp_path.empty()) { - result.found_mtp = true; - result.mtp.path = download_result.mtp_path; - } - } else if (!model.url.empty()) { - if (model.path.empty()) { - auto f = string_split(model.url, '#').front(); - f = string_split(f, '?').front(); - model.path = fs_get_cache_file(string_split(f, '/').back()); - } - - auto download_result = common_download_model(model, opts); - if (download_result.model_path.empty()) { - throw std::runtime_error("failed to download model from " + model.url); - } - } - - return result; -} - const std::vector kv_cache_types = { GGML_TYPE_F32, GGML_TYPE_F16, @@ -394,72 +341,204 @@ static bool parse_bool_value(const std::string & value) { } // -// CLI argument parsing functions +// common_models_handler // -bool common_params_handle_models(common_params & params, llama_example curr_ex) { - const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), - params.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); +static std::string get_default_local_path(const std::string & url) { + auto f = string_split(url, '#').front(); + f = string_split(f, '?').front(); + return fs_get_cache_file(string_split(f, '/').back()); +} +common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex) { + common_download_hf_plan plan; common_download_opts opts; + + const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), + params.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); + + // only download mmproj if the current example is using it + bool use_mmproj = false; + for (const auto & ex : mmproj_examples) { + if (curr_ex == ex) { + use_mmproj = true; + break; + } + } + opts.bearer_token = params.hf_token; opts.offline = params.offline; - opts.skip_download = params.skip_download; opts.download_mtp = spec_type_draft_mtp; - opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty(); + opts.download_mmproj = use_mmproj && !params.no_mmproj + && params.mmproj.path.empty() && params.mmproj.url.empty(); - // sub-models (draft, mmproj, vocoder) are explicitly specified by the user, - // so we should not auto-discover mtp/mmproj siblings for them - common_download_opts sub_opts = opts; - sub_opts.download_mtp = false; - sub_opts.download_mmproj = false; + if (!params.model.hf_repo.empty()) { + plan = common_download_get_hf_plan(params.model, opts); + } - try { - auto res = common_params_handle_model(params.model, opts); - if (res.found_preset) { - if (!params.models_preset.empty()) { - throw std::invalid_argument("cannot use both --models-preset and -hf with a preset.ini file"); + return common_models_handler{plan, opts}; +} + +bool common_models_handler_is_preset_repo(const common_models_handler & handler) { + return !handler.plan.preset.url.empty(); +} + +static std::vector build_url_tasks(const common_params_model & model, common_download_opts opts) { + auto parts = common_download_get_all_parts(model.url); + std::vector tasks; + + // single-part: download straight to model.path if the user gave one (-m), else the cache default + if (parts.size() == 1) { + common_download_task task; + task.url = parts[0]; + task.local_path = model.path.empty() ? get_default_local_path(parts[0]) : model.path; + task.opts = opts; + tasks.push_back(std::move(task)); + return tasks; + } + + // multi-part: place each part under the user's -m directory (if given), else the cache default + std::string base_dir; + if (!model.path.empty()) { + auto pos = model.path.rfind('/'); + base_dir = pos == std::string::npos ? std::string(".") : model.path.substr(0, pos); + } + + for (const auto & part : parts) { + common_download_task task; + task.url = part; + task.opts = opts; + + std::string local = get_default_local_path(part); + if (!base_dir.empty()) { + auto pos = local.rfind('/'); + std::string name = pos == std::string::npos ? local : local.substr(pos + 1); + local = base_dir + "/" + name; + } + task.local_path = local; + tasks.push_back(std::move(task)); + } + return tasks; +} + +void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback) { + std::vector tasks; + + auto & plan = handler.plan; + + auto opts = handler.opts; // copy + opts.callback = callback; + + // handle plain "url" if needed + auto handle_url = [&](common_params_model & model) { + if (!model.url.empty()) { + if (model.path.empty()) { + model.path = get_default_local_path(model.url); } + } + }; + handle_url(params.model); + handle_url(params.mmproj); + handle_url(params.vocoder.model); + handle_url(params.speculative.draft.mparams); + + // optionally, if docker repo is set, resolve it + if (!params.model.docker_repo.empty()) { + params.model.url = common_docker_resolve_model(params.model.docker_repo); + params.model.path = get_default_local_path(params.model.url); + } + + // handle plain "url" tasks (non-hf) + if (!params.model.url.empty()) { + auto url_tasks = build_url_tasks(params.model, opts); + // the first part is what gets loaded, so point params.model.path at it + if (!url_tasks.empty()) { + std::string first_path = url_tasks.front().local_path; + url_tasks.front().on_done = [&]() { params.model.path = first_path; }; + } + for (auto & task : url_tasks) { + tasks.push_back(std::move(task)); + } + } + if (!params.mmproj.url.empty()) { + common_download_task task; + task.url = params.mmproj.url; + task.local_path = params.mmproj.path; + task.opts = opts; + tasks.push_back(task); + } + if (!params.vocoder.model.url.empty()) { + common_download_task task; + task.url = params.vocoder.model.url; + task.local_path = params.vocoder.model.path; + task.opts = opts; + tasks.push_back(task); + } + if (!params.speculative.draft.mparams.url.empty()) { + common_download_task task; + task.url = params.speculative.draft.mparams.url; + task.local_path = params.speculative.draft.mparams.path; + task.opts = opts; + tasks.push_back(task); + } + + // handle hf_plan tasks + if (!plan.model_files.empty()) { + for (size_t i = 0; i < plan.model_files.size(); ++i) { + auto & model_file = plan.model_files[i]; + bool is_first = (i == 0); + tasks.emplace_back(model_file, opts, [&, is_first]() { + if (is_first) { + // only use first part as model path + params.model.path = hf_cache::finalize_file(model_file); + } else { + hf_cache::finalize_file(model_file); + } + }); + } + } + if (!plan.mmproj.local_path.empty()) { + tasks.emplace_back(plan.mmproj, opts, [&]() { + params.mmproj.path = hf_cache::finalize_file(plan.mmproj); + }); + } + if (!plan.mtp.local_path.empty()) { + tasks.emplace_back(plan.mtp, opts, [&]() { + // only fall back to the discovered MTP head when no draft was explicitly provided + if (params.speculative.draft.mparams.empty()) { + params.speculative.draft.mparams.path = hf_cache::finalize_file(plan.mtp); + } else { + hf_cache::finalize_file(plan.mtp); + } + }); + } + if (!plan.preset.local_path.empty()) { + tasks.emplace_back(plan.preset, opts, [&]() { // if HF repo is a preset repo, we simply run server in router mode with the preset.ini file params.models_preset_hf = params.model.hf_repo; // only for showing a warning - params.models_preset = res.preset_path; + params.models_preset = hf_cache::finalize_file(plan.preset); params.model = common_params_model{}; // make sure to clear model, so server starts in router mode - return true; - } + }); + } - if (params.no_mmproj) { - params.mmproj = {}; - } else if (res.found_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty()) { - // optionally, handle mmproj model when -hf is specified - params.mmproj = res.mmproj; - } - // only download mmproj if the current example is using it - for (const auto & ex : mmproj_examples) { - if (curr_ex == ex) { - common_params_handle_model(params.mmproj, sub_opts); - break; - } - } + // run all tasks in parallel + if (!params.offline) { + common_download_run_tasks(tasks); + } - // when --spec-type mtp is set and no draft model was provided explicitly, - // fall back to the MTP head discovered alongside the -hf model - if (spec_type_draft_mtp && res.found_mtp && - params.speculative.draft.mparams.path.empty() && - params.speculative.draft.mparams.hf_repo.empty() && - params.speculative.draft.mparams.url.empty()) { - params.speculative.draft.mparams.path = res.mtp.path; + // download successful, update params with the downloaded paths + for (const auto & task : tasks) { + if (task.on_done) { + task.on_done(); } - common_params_handle_model(params.speculative.draft.mparams, sub_opts); - common_params_handle_model(params.vocoder.model, sub_opts); - return true; - } catch (const common_skip_download_exception &) { - return false; - } catch (const std::exception &) { - throw; } } +// +// CLI argument parsing functions +// + static bool common_params_parse_ex(int argc, char ** argv, common_params_context & ctx_arg) { common_params & params = ctx_arg.params; @@ -585,17 +664,22 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context throw std::invalid_argument("error: --prompt-cache-all not supported in interactive mode yet\n"); } - // export_graph_ops loads only metadata - const bool skip_model_download = ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS; + const bool skip_model_download = + // server will call common_params_handle_models() later, so we skip it here + ctx_arg.ex == LLAMA_EXAMPLE_SERVER || + // download calls common_params_handle_models() itself and prints the paths + ctx_arg.ex == LLAMA_EXAMPLE_DOWNLOAD || + // export_graph_ops loads only metadata + ctx_arg.ex == LLAMA_EXAMPLE_EXPORT_GRAPH_OPS; if (!skip_model_download) { // handle model and download - common_params_handle_models(params, ctx_arg.ex); + common_models_handler handler = common_models_handler_init(params, ctx_arg.ex); + common_models_handler_apply(handler, params); // model is required (except for server) // TODO @ngxson : maybe show a list of available models in CLI in this case if (params.model.path.empty() - && ctx_arg.ex != LLAMA_EXAMPLE_SERVER && !params.usage && !params.completion) { throw std::invalid_argument("error: --model is required\n"); @@ -663,15 +747,19 @@ static void common_params_print_usage(common_params_context & ctx_arg) { common_options.push_back(&opt); } } - printf("----- common params -----\n\n"); - print_options(common_options); - printf("\n\n----- sampling params -----\n\n"); - print_options(sampling_options); - printf("\n\n----- speculative params -----\n\n"); - print_options(spec_options); - // TODO: maybe convert enum llama_example to string - printf("\n\n----- example-specific params -----\n\n"); - print_options(specific_options); + bool first = true; + auto print_section = [&](const char * header, std::vector & options) { + if (options.empty()) { + return; + } + printf("%s----- %s -----\n\n", first ? "" : "\n\n", header); + first = false; + print_options(options); + }; + print_section("common params", common_options); + print_section("sampling params", sampling_options); + print_section("speculative params", spec_options); + print_section("example-specific params", specific_options); } static void common_params_print_completion(common_params_context & ctx_arg) { @@ -893,7 +981,44 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map buf; + std::vector ptrs; +}; + +static utf8_argv make_utf8_argv() { + utf8_argv out; + int wargc = 0; + LPWSTR* wargv = CommandLineToArgvW(GetCommandLineW(), &wargc); + if (!wargv) return out; + + out.buf.reserve(wargc); + for (int i = 0; i < wargc; ++i) { + int n = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS, wargv[i], -1, nullptr, 0, nullptr, nullptr); + if (n <= 0) { out.buf.emplace_back(); continue; } + auto& s = out.buf.emplace_back(); + s.resize(static_cast(n - 1)); + (void)WideCharToMultiByte(CP_UTF8, 0, wargv[i], -1, s.data(), n, nullptr, nullptr); + } + LocalFree(wargv); + + out.ptrs.reserve(out.buf.size() + 1); + for (auto& s : out.buf) out.ptrs.push_back(s.data()); + out.ptrs.push_back(nullptr); + return out; +} +#endif + bool common_params_parse(int argc, char ** argv, common_params & params, llama_example ex, void(*print_usage)(int, char **)) { +#ifdef _WIN32 + auto utf8 = make_utf8_argv(); + // repair argv only when it matches the process command line + if (static_cast(utf8.buf.size()) == argc) { + argv = utf8.ptrs.data(); + } +#endif + auto ctx_arg = common_params_parser_init(params, ex, print_usage); const common_params params_org = ctx_arg.params; // the example can modify the default params @@ -1034,7 +1159,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex * - if both {LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_*,} are set, we will prioritize the LLAMA_EXAMPLE_* matching current example */ auto add_opt = [&](common_arg arg) { - if ((arg.in_example(ex) || arg.in_example(LLAMA_EXAMPLE_COMMON)) && !arg.is_exclude(ex)) { + // download only exposes the handful of args explicitly tagged for it + const bool inherit_common = ex != LLAMA_EXAMPLE_DOWNLOAD; + if ((arg.in_example(ex) || (inherit_common && arg.in_example(LLAMA_EXAMPLE_COMMON))) && !arg.is_exclude(ex)) { ctx_arg.options.push_back(std::move(arg)); } }; @@ -1045,7 +1172,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.usage = true; } - )); + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD})); add_opt(common_arg( {"--version"}, "show version and build info", @@ -2167,7 +2294,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, bool value) { params.no_mmproj = !value; } - ).set_examples(mmproj_examples).set_env("LLAMA_ARG_MMPROJ_AUTO")); + ).set_examples({LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MMPROJ_AUTO")); add_opt(common_arg( {"--mmproj-offload"}, {"--no-mmproj-offload"}, @@ -2566,14 +2693,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.model.path = value; } - ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA}).set_env("LLAMA_ARG_MODEL")); + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_EXPORT_LORA, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MODEL")); add_opt(common_arg( {"-mu", "--model-url"}, "MODEL_URL", "model download url (default: unused)", [](common_params & params, const std::string & value) { params.model.url = value; } - ).set_env("LLAMA_ARG_MODEL_URL")); + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_MODEL_URL")); add_opt(common_arg( { "-dr", "--docker-repo" }, "[/][:quant]", "Docker Hub model repository. repo is optional, default to ai/. quant is optional, default to :latest.\n" @@ -2582,7 +2709,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.model.docker_repo = value; } - ).set_env("LLAMA_ARG_DOCKER_REPO")); + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_DOCKER_REPO")); add_opt(common_arg( {"-hf", "-hfr", "--hf-repo"}, "/[:quant]", "Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.\n" @@ -2592,14 +2719,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.model.hf_repo = value; } - ).set_env("LLAMA_ARG_HF_REPO")); + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_HF_REPO")); add_opt(common_arg( {"-hff", "--hf-file"}, "FILE", "Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)", [](common_params & params, const std::string & value) { params.model.hf_file = value; } - ).set_env("LLAMA_ARG_HF_FILE")); + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("LLAMA_ARG_HF_FILE")); add_opt(common_arg( {"-hfv", "-hfrv", "--hf-repo-v"}, "/[:quant]", "Hugging Face model repository for the vocoder model (default: unused)", @@ -2620,7 +2747,14 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, const std::string & value) { params.hf_token = value; } - ).set_env("HF_TOKEN")); + ).set_examples({LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_DOWNLOAD}).set_env("HF_TOKEN")); + add_opt(common_arg( + {"--mtp"}, + "also download the multi-token prediction (MTP) head, if available (default: unused)", + [](common_params & params) { + params.speculative.types.push_back(COMMON_SPECULATIVE_TYPE_DRAFT_MTP); + } + ).set_examples({LLAMA_EXAMPLE_DOWNLOAD})); add_opt(common_arg( {"--context-file"}, "FNAME", "file to load context from (use comma-separated values to specify multiple files)", @@ -2830,62 +2964,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.api_prefix = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_API_PREFIX")); - // Deprecated: use --ui-config instead (kept for backward compat) add_opt(common_arg( - {"--webui-config"}, "JSON", - "[DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)", - [](common_params & params, const std::string & value) { - params.ui_config_json = value; - params.webui_config_json = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG")); - - add_opt(common_arg( - {"--ui-config"}, "JSON", + {"--ui-config", "--webui-config"}, "JSON", "JSON that provides default UI settings (overrides UI defaults)", [](common_params & params, const std::string & value) { params.ui_config_json = value; - params.webui_config_json = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG")); - - // Deprecated: use --ui-config-file instead (kept for backward compat) add_opt(common_arg( - {"--webui-config-file"}, "PATH", - "[DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)", - [](common_params & params, const std::string & value) { - params.ui_config_json = read_file(value); - params.webui_config_json = params.ui_config_json; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_CONFIG_FILE")); - - add_opt(common_arg( - {"--ui-config-file"}, "PATH", + {"--ui-config-file", "--webui-config-file"}, "PATH", "JSON file that provides default UI settings (overrides UI defaults)", [](common_params & params, const std::string & value) { params.ui_config_json = read_file(value); - params.webui_config_json = params.ui_config_json; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_CONFIG_FILE")); - - // Deprecated: use --ui-mcp-proxy instead (kept for backward compat) add_opt(common_arg( - {"--webui-mcp-proxy"}, - {"--no-webui-mcp-proxy"}, - "[DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy", - [](common_params & params, bool value) { - params.ui_mcp_proxy = value; - params.webui_mcp_proxy = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI_MCP_PROXY")); - - add_opt(common_arg( - {"--ui-mcp-proxy"}, - {"--no-ui-mcp-proxy"}, + {"--ui-mcp-proxy", "--webui-mcp-proxy"}, + {"--no-ui-mcp-proxy", "--no-webui-mcp-proxy"}, "experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)", [](common_params & params, bool value) { params.ui_mcp_proxy = value; - params.webui_mcp_proxy = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI_MCP_PROXY")); add_opt(common_arg( @@ -2897,24 +2995,26 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.server_tools = parse_csv_row(value); } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TOOLS")); - // Deprecated: use --ui/--no-ui instead (kept for backward compat) add_opt(common_arg( - {"--webui"}, - {"--no-webui"}, - "[DEPRECATED: use --ui/--no-ui] whether to enable the Web UI", + {"-ag", "--agent"}, + {"-no-ag", "--no-agent"}, + "whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)", [](common_params & params, bool value) { - params.ui = value; - params.webui = value; + if (value) { + params.server_tools = {"all"}; + params.ui_mcp_proxy = true; + } else { + params.server_tools.clear(); + params.ui_mcp_proxy = false; + } } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_WEBUI")); - + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_AGENT")); add_opt(common_arg( - {"--ui"}, - {"--no-ui"}, + {"--ui", "--webui"}, + {"--no-ui", "--no-webui"}, string_format("whether to enable the Web UI (default: %s)", params.ui ? "enabled" : "disabled"), [](common_params & params, bool value) { params.ui = value; - params.webui = value; } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_UI")); add_opt(common_arg( @@ -2945,7 +3045,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); add_opt(common_arg( {"--api-key-file"}, "FNAME", - "path to file containing API keys (default: none)", + "path to file containing API keys, one per line; lines starting with a hash are treated as comments (default: none)", [](common_params & params, const std::string & value) { std::ifstream key_file(value); if (!key_file) { @@ -2953,7 +3053,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } std::string key; while (std::getline(key_file, key)) { - if (!key.empty()) { + if (!key.empty() && key[0] != '#') { params.api_keys.push_back(key); } } diff --git a/common/arg.h b/common/arg.h index 0010f2a9ac..508e33d29e 100644 --- a/common/arg.h +++ b/common/arg.h @@ -1,12 +1,14 @@ #pragma once #include "common.h" +#include "download.h" #include #include #include #include #include +#include // pseudo-env variable to identify preset-only arguments #define COMMON_ARG_PRESET_LOAD_ON_STARTUP "__PRESET_LOAD_ON_STARTUP" @@ -129,11 +131,19 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map & args); -// populate model paths (main model, mmproj, etc) from -hf if necessary -// return true if the model is ready to use -// throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc) -// if params.skip_download is true, no downloads will be attempted. return false if the model is invalid or missing (e.g. ETag check failed) -bool common_params_handle_models(common_params & params, llama_example curr_ex); +struct common_models_handler { + common_download_hf_plan plan; + common_download_opts opts; +}; + +// initialize downloading opts and hf_plan if needed, but does not download anything yet +common_models_handler common_models_handler_init(const common_params & params, llama_example curr_ex); + +// check if the model is a preset repo (i.e. has a preset file) +bool common_models_handler_is_preset_repo(const common_models_handler & handler); + +// download and update params with the downloaded model path +void common_models_handler_apply(common_models_handler & handler, common_params & params, common_download_callback * callback = nullptr); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/common/chat-auto-parser-generator.cpp b/common/chat-auto-parser-generator.cpp index 37ca55c8df..36aab7ecbe 100644 --- a/common/chat-auto-parser-generator.cpp +++ b/common/chat-auto-parser-generator.cpp @@ -395,10 +395,11 @@ common_peg_parser analyze_tools::build_tool_parser_tag_tagged(parser_build_conte arguments.name_suffix) + arguments.value_prefix + (schema_info.resolves_to_string(param_schema) ? - p.tool_arg_string_value(until_suffix) : - p.tool_arg_json_value(p.schema( - p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false))) + - p.tool_arg_close(p.literal(arguments.value_suffix))); + p.ac(p.tool_arg_string_value(until_suffix) + + p.tool_arg_close(p.literal(arguments.value_suffix)), arguments.value_suffix) : + (p.tool_arg_json_value(p.schema( + p.json(), "tool-" + name + "-arg-" + param_name + "-schema", param_schema, false)) + + p.tool_arg_close(p.literal(arguments.value_suffix))))); auto named_arg = p.rule("tool-" + name + "-arg-" + param_name, arg); if (is_required) { diff --git a/common/chat.cpp b/common/chat.cpp index ded8440e66..0cee80434e 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const return text; } -std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims) { - if (delims.empty() || prompt.empty()) { - return {}; +common_chat_role common_chat_role_from_string(const std::string & role) { + if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; } + if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; } + if (role == "user") { return COMMON_CHAT_ROLE_USER; } + if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; } + return COMMON_CHAT_ROLE_UNKNOWN; +} + +const char * common_chat_role_to_string(common_chat_role role) { + switch (role) { + case COMMON_CHAT_ROLE_SYSTEM: return "system"; + case COMMON_CHAT_ROLE_ASSISTANT: return "assistant"; + case COMMON_CHAT_ROLE_USER: return "user"; + case COMMON_CHAT_ROLE_TOOL: return "tool"; + case COMMON_CHAT_ROLE_UNKNOWN: return ""; + } + return ""; +} + +json common_chat_msg_delimiters::to_json() const { + json result = json::array(); + for (const auto & d : delimiters) { + result.push_back({ + { "role", common_chat_role_to_string(d.role) }, + { "delimiter", d.delimiter }, + }); + } + return result; +} + +common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) { + common_chat_msg_delimiters result; + + if (!delimiters.is_array()) { + return result; } - auto parser = build_peg_parser([&](common_peg_parser_builder & p) { - std::vector all_delims; - std::vector tagged_messages; - - all_delims.reserve(delims.size()); - tagged_messages.reserve(delims.size()); - for (const auto & d : delims) { - all_delims.push_back(d.delimiter); + result.delimiters.reserve(delimiters.size()); + for (const auto & d : delimiters) { + if (!d.is_object()) { + continue; } - - auto any_delim = p.until_one_of(all_delims); - for (const auto & d : delims) { - tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim)); - } - - return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end(); - }); - - common_peg_parse_context ctx(prompt); - const auto result = parser.parse(ctx); - if (!result.success()) { - return {}; + result.delimiters.push_back({ + common_chat_role_from_string(d.value("role", std::string())), + d.value("delimiter", std::string()), + }); } - std::vector spans; - ctx.ast.visit(result, [&](const common_peg_ast_node & node) { - if (!node.tag.empty()) { - spans.push_back({ node.tag, node.start, node.end - node.start }); + return result; +} + +void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) { + for (auto & d : delimiters) { + d.tokens = common_tokenize(vocab, d.delimiter, false, true); + } +} + +common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map & skips) const { + std::vector> matches; + + auto skip = skips.begin(); + for (size_t i = 0; i < tokens.size();) { + if (skip != skips.end() && i == skip->first) { + i += skip->second; + ++skip; + continue; } - }); + for (const auto & d : delimiters) { + if (i + d.tokens.size() > tokens.size()) { + continue; + } + if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) { + matches.emplace_back(d.role, i); + break; + } + } + i++; + } + + matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size()); + + common_chat_msg_spans spans; + for (size_t i = 0; i + 1 < matches.size(); i++) { + const auto & curr = matches[i]; + const auto & next = matches[i + 1]; + spans.add(curr.first, curr.second, next.second - curr.second); + } return spans; } @@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp data.prompt = prompt; data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages); - data.message_spans = common_chat_split_by_role(prompt, { - { "assistant", "<|start|>assistant" }, - { "user", "<|start|>user" }, - { "system", "<|start|>developer" }, - { "system", "<|start|>system" }, - { "tool", "<|start|>functions" }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" }, + { COMMON_CHAT_ROLE_USER, "<|start|>user" }, + { COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" }, + { COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" }, + { COMMON_CHAT_ROLE_TOOL, "<|start|>functions" }, + }; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; @@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ data.prompt += data.generation_prompt; } - data.message_spans = common_chat_split_by_role(data.prompt, { - { "user", "<|turn>user\n" }, - { "assistant", "<|turn>model\n" }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_USER, "<|turn>user" }, + { COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" }, + }; data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; data.supports_thinking = true; @@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t RESULT_START, RESULT_END, }; - // Split the rendered prompt into per-role message spans. Tool results are rendered with the + // Declare per-role message delimiters. Tool results are rendered with the // system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before // the plain "system" one (it is a strict superset, and the role split tries delimiters in order). - data.message_spans = common_chat_split_by_role(data.prompt, { - { "assistant", GEN_PREFIX }, - { "user", TURN_START + USER }, - { "tool", TURN_START + SYSTEM + RESULT_START }, - { "system", TURN_START + SYSTEM }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX }, + { COMMON_CHAT_ROLE_USER, TURN_START + USER }, + { COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START }, + { COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM }, + }; auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; @@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ autoparser.analyze_template(tmpl); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); - std::vector delimiters; + common_chat_msg_delimiters delimiters; if (!autoparser.assistant_start.empty()) { - delimiters.push_back({ "assistant", autoparser.assistant_start }); + delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start); } if (!autoparser.user_start.empty()) { - delimiters.push_back({ "user", autoparser.user_start }); + delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start); } - if (!delimiters.empty()) { - auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters); - } + auto_params.message_delimiters = std::move(delimiters); auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; if (auto_params.supports_thinking) { @@ -2708,5 +2758,9 @@ common_chat_msg common_chat_peg_parse(const common_peg_arena & src_pars std::map common_chat_templates_get_caps(const common_chat_templates * chat_templates) { GGML_ASSERT(chat_templates != nullptr); GGML_ASSERT(chat_templates->template_default != nullptr); + if (chat_templates->template_tool_use != nullptr) { + // take the more expressive template when available + return chat_templates->template_tool_use->caps.to_map(); + } return chat_templates->template_default->caps.to_map(); } diff --git a/common/chat.h b/common/chat.h index 5659cd42a0..7898f1623f 100644 --- a/common/chat.h +++ b/common/chat.h @@ -143,15 +143,75 @@ struct common_chat_msg_diff { } }; +enum common_chat_role { + COMMON_CHAT_ROLE_UNKNOWN, + COMMON_CHAT_ROLE_SYSTEM, + COMMON_CHAT_ROLE_ASSISTANT, + COMMON_CHAT_ROLE_USER, + COMMON_CHAT_ROLE_TOOL +}; + +common_chat_role common_chat_role_from_string(const std::string & role); +const char * common_chat_role_to_string(common_chat_role role); + struct common_chat_msg_span { - std::string role; + common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN; std::size_t pos = 0; std::size_t len = 0; + + bool valid() const { + return role != COMMON_CHAT_ROLE_UNKNOWN; + } +}; + +struct common_chat_msg_spans { + std::vector spans; + + void add(common_chat_role role, size_t pos, size_t len) { + spans.push_back({ role, pos, len }); + } + + bool is_user_start(int32_t pos) const { + for (auto it = spans.begin(); it != spans.end(); ++it) { + if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) { + return true; + } + } + return false; + } + + int32_t last_user_message_pos() const { + for (auto it = spans.rbegin(); it != spans.rend(); ++it) { + if (it->role == COMMON_CHAT_ROLE_USER) { + return (int32_t) it->pos; + } + } + return -1; + } }; struct common_chat_msg_delimiter { - std::string role; - std::string delimiter; + common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN; + std::string delimiter; + llama_tokens tokens = {}; +}; + +struct common_chat_msg_delimiters { + std::vector delimiters; + + common_chat_msg_delimiters() = default; + common_chat_msg_delimiters(std::initializer_list delims) : delimiters(delims) {} + + void add(common_chat_role role, const std::string & delimiter) { + delimiters.push_back({ role, delimiter }); + } + + void tokenize(const llama_vocab * vocab); + + // split tokens into message spans. skips maps a start index to a length of a region to jump over without matching + common_chat_msg_spans split(const llama_tokens & tokens, const std::map & skips = {}) const; + + nlohmann::ordered_json to_json() const; }; struct common_chat_tool { @@ -219,7 +279,7 @@ struct common_chat_params { std::vector preserved_tokens; std::vector additional_stops; std::string parser; - std::vector message_spans; + common_chat_msg_delimiters message_delimiters; }; // per-message parsing syntax @@ -325,5 +385,4 @@ struct common_chat_prompt_preset { common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates); -std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims); - +common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters); diff --git a/common/common.cpp b/common/common.cpp index b01772e1cb..a14e7bbed9 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1074,6 +1074,18 @@ std::vector fs_list(const std::string & path, bool include_dir return files; } +std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode) { +#ifdef _WIN32 + int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0); + if (!wlen) { return std::ifstream(); } + std::vector wfname(wlen); + (void)MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wfname.data(), wlen); + return std::ifstream(wfname.data(), mode); +#else + return std::ifstream(fname, mode); +#endif +} + // // TTY utils // @@ -2034,7 +2046,7 @@ bool common_prompt_batch_decode( } size_t common_prompt_checkpoint::size() const { - return data_tgt.size() + data_dft.size(); + return data_tgt.size() + data_dft.size() + data_spec.size(); } bool common_prompt_checkpoint::empty() const { @@ -2049,6 +2061,7 @@ void common_prompt_checkpoint::clear() { data_tgt.clear(); data_dft.clear(); + data_spec.clear(); } void common_prompt_checkpoint::update_pos( @@ -2138,4 +2151,5 @@ void common_prompt_checkpoint::clear_tgt() { void common_prompt_checkpoint::clear_dft() { data_dft.clear(); + data_spec.clear(); } diff --git a/common/common.h b/common/common.h index 040b9cf233..94147d5d8c 100644 --- a/common/common.h +++ b/common/common.h @@ -96,6 +96,7 @@ enum llama_example { LLAMA_EXAMPLE_FIT_PARAMS, LLAMA_EXAMPLE_RESULTS, LLAMA_EXAMPLE_EXPORT_GRAPH_OPS, + LLAMA_EXAMPLE_DOWNLOAD, LLAMA_EXAMPLE_COUNT, }; @@ -290,12 +291,25 @@ struct common_params_sampling { }; struct common_params_model { - std::string path = ""; // model local path // NOLINT - std::string url = ""; // model url to download // NOLINT - std::string hf_repo = ""; // HF repo // NOLINT - std::string hf_file = ""; // HF file // NOLINT - std::string docker_repo = ""; // Docker repo // NOLINT - std::string name = ""; // in format /[:] (tag is optional) // NOLINT + std::string path = ""; // model local path + std::string url = ""; // model url to download + std::string hf_repo = ""; // HF repo + std::string hf_file = ""; // HF file + std::string docker_repo = ""; // Docker repo + + std::string get_name() const { + if (!hf_repo.empty()) { + return hf_repo; + } + if (!docker_repo.empty()) { + return docker_repo; + } + return path; + } + + bool empty() const { + return get_name().empty(); + } }; // draft-model-based speculative decoding parameters @@ -358,12 +372,12 @@ struct common_params_speculative { common_params_speculative_ngram_cache ngram_cache; bool has_dft() const { - return !draft.mparams.path.empty() || !draft.mparams.hf_repo.empty(); + return !draft.mparams.empty(); } uint32_t need_n_rs_seq() const { bool needs_rs_seq = std::any_of(types.begin(), types.end(), [&](auto t) { - return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP; + return t == COMMON_SPECULATIVE_TYPE_DRAFT_MTP || t == COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3; }); return needs_rs_seq ? draft.n_max : 0u; @@ -510,7 +524,6 @@ struct common_params { int32_t control_vector_layer_start = -1; // layer range for control vector int32_t control_vector_layer_end = -1; // layer range for control vector bool offline = false; - bool skip_download = false; // skip model file downloading int32_t ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used. int32_t ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line @@ -600,7 +613,7 @@ struct common_params { bool cache_prompt = true; // whether to enable prompt caching bool cache_idle_slots = true; // save and clear idle slots upon starting a new task int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot - int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints + int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; @@ -624,12 +637,6 @@ struct common_params { // UI configs bool ui = true; - - // Deprecated: use ui, ui_mcp_proxy, ui_config_json instead - bool webui = ui; - bool webui_mcp_proxy = false; - std::string webui_config_json; - bool ui_mcp_proxy = false; std::string ui_config_json; @@ -848,6 +855,9 @@ struct common_file_info { }; std::vector fs_list(const std::string & path, bool include_directories); +// fs open, also handle UTF8 on Windows +std::ifstream fs_open_ifstream(const std::string & fname, std::ios_base::openmode mode); + // // TTY utils // @@ -1065,6 +1075,10 @@ struct common_prompt_checkpoint { std::vector data_tgt; std::vector data_dft; + // (optional) speculative-decoding implementation state stashed with the checkpoint + // (e.g. eagle3's deferred-boundary g_embd row) + std::vector data_spec; + size_t size() const; bool empty() const; diff --git a/common/download.cpp b/common/download.cpp index f320462753..6b69a44188 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -292,10 +292,6 @@ static int common_download_file_single_online(const std::string & url, const bool file_exists = std::filesystem::exists(path); - if (!file_exists && opts.skip_download) { - return -2; // file is missing and download is disabled - } - if (file_exists && skip_etag) { LOG_DBG("%s: using cached file: %s\n", __func__, path.c_str()); return 304; // 304 Not Modified - fake cached response @@ -362,9 +358,6 @@ static int common_download_file_single_online(const std::string & url, return 304; // 304 Not Modified - fake cached response } // pass this point, the file exists but is different from the server version, so we need to redownload it - if (opts.skip_download) { - return -2; // special code to indicate that the download was skipped due to etag mismatch - } if (remove(path.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); return -1; @@ -691,19 +684,8 @@ static void list_available_gguf_files(const hf_cache::hf_files & files) { } } -struct hf_plan { - hf_cache::hf_file primary; - hf_cache::hf_files model_files; - hf_cache::hf_file mmproj; - hf_cache::hf_file mtp; - hf_cache::hf_file preset; // if set, only this file is downloaded -}; - -static hf_plan get_hf_plan(const common_params_model & model, - const common_download_opts & opts, - bool download_mmproj, - bool download_mtp) { - hf_plan plan; +common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts) { + common_download_hf_plan plan; hf_cache::hf_files all; auto [repo, tag] = common_download_split_repo_tag(model.hf_repo); @@ -752,125 +734,49 @@ static hf_plan get_hf_plan(const common_params_model & model, plan.primary = primary; plan.model_files = get_split_files(all, primary); - if (download_mmproj) { + if (opts.download_mmproj) { plan.mmproj = find_best_mmproj(all, primary.path); } - - if (download_mtp) { + if (opts.download_mtp) { plan.mtp = find_best_mtp(all, primary.path); } return plan; } -struct download_task { - std::string url; - std::string path; -}; - -static std::vector get_url_tasks(const common_params_model & model) { - auto split = get_gguf_split_info(model.url); - - if (split.count <= 1) { - return {{model.url, model.path}}; - } - - auto filename = split.prefix; - if (auto pos = split.prefix.rfind('/'); pos != std::string::npos) { - filename = split.prefix.substr(pos + 1); - } - - auto parent_path = std::filesystem::path(model.path).parent_path(); - auto prefix_path = (parent_path / filename).string(); - - std::vector tasks; - for (int i = 1; i <= split.count; i++) { - auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count); - tasks.push_back({split.prefix + suffix, prefix_path + suffix}); - } - return tasks; -} - -common_download_model_result common_download_model(const common_params_model & model, - const common_download_opts & opts) { - common_download_model_result result; - std::vector tasks; - hf_plan hf; - - bool download_mmproj = opts.download_mmproj; - bool download_mtp = opts.download_mtp; - bool is_hf = !model.hf_repo.empty(); - - if (is_hf) { - hf = get_hf_plan(model, opts, download_mmproj, download_mtp); - if (!hf.preset.path.empty()) { - // if preset.ini exists, only download that file alone - tasks.push_back({hf.preset.url, hf.preset.local_path}); - } else { - for (const auto & f : hf.model_files) { - tasks.push_back({f.url, f.local_path}); - } - if (!hf.mmproj.path.empty()) { - tasks.push_back({hf.mmproj.url, hf.mmproj.local_path}); - } - if (!hf.mtp.path.empty()) { - tasks.push_back({hf.mtp.url, hf.mtp.local_path}); - } - } - } else if (!model.url.empty()) { - tasks = get_url_tasks(model); - } else { - result.model_path = model.path; - return result; - } - - if (tasks.empty()) { - return result; - } - +void common_download_run_tasks(const std::vector & tasks) { std::vector> futures; for (const auto & task : tasks) { futures.push_back(std::async(std::launch::async, - [&task, &opts, is_hf]() { - return common_download_file_single(task.url, task.path, opts, is_hf); + [&task]() { + return common_download_file_single(task.url, task.local_path, task.opts, task.is_hf); } )); } - for (auto & f : futures) { - int status = f.get(); - if (status == -2 && opts.skip_download) { - throw common_skip_download_exception(); - } + for (size_t i = 0; i < futures.size(); ++i) { + std::string url = tasks[i].url; + int status = futures[i].get(); bool is_ok = is_http_status_ok(status); if (!is_ok) { - return {}; + throw std::runtime_error(string_format("Download '%s' failed with status code: %d", url.c_str(), status)); } } +} - if (is_hf) { - if (!hf.preset.path.empty()) { - // if preset.ini is used, do not set other paths - result.preset_path = hf_cache::finalize_file(hf.preset); - } else { - for (const auto & f : hf.model_files) { - hf_cache::finalize_file(f); - } - result.model_path = hf.primary.final_path; +std::vector common_download_get_all_parts(const std::string & url) { + auto split = get_gguf_split_info(url); - if (!hf.mmproj.path.empty()) { - result.mmproj_path = hf_cache::finalize_file(hf.mmproj); - } - - if (!hf.mtp.path.empty()) { - result.mtp_path = hf_cache::finalize_file(hf.mtp); - } - } - } else { - result.model_path = model.path; + if (split.count <= 1) { + return {url}; } - return result; + std::vector parts; + for (int i = 1; i <= split.count; i++) { + auto suffix = string_format("-%05d-of-%05d.gguf", i, split.count); + parts.push_back(split.prefix + suffix); + } + return parts; } // diff --git a/common/download.h b/common/download.h index 8dbf07836f..816e1c7f58 100644 --- a/common/download.h +++ b/common/download.h @@ -1,7 +1,10 @@ #pragma once +#include "hf-cache.h" + #include #include +#include struct common_params_model; @@ -47,66 +50,40 @@ struct common_cached_model_info { } }; -// Options for common_download_model and common_download_file_single +// Options for common_download_file_single struct common_download_opts { std::string bearer_token; common_header_list headers; bool offline = false; - bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid bool download_mmproj = false; bool download_mtp = false; common_download_callback * callback = nullptr; }; -// Result of common_download_model -struct common_download_model_result { - std::string model_path; - std::string mmproj_path; - std::string mtp_path; - std::string preset_path; +struct common_download_task { + common_download_opts opts; + std::string url; + std::string local_path; + std::function on_done; + bool is_hf = false; + + common_download_task() = default; + common_download_task(hf_cache::hf_file f, + const common_download_opts & opts, + std::function on_done = nullptr) + : opts(opts), url(f.url), local_path(f.local_path), on_done(on_done), is_hf(true) {} }; -// throw if the file is missing or invalid (e.g. ETag check failed) -struct common_skip_download_exception : public std::runtime_error { - common_skip_download_exception() : std::runtime_error("skip download") {} -}; +void common_download_run_tasks(const std::vector & tasks); -// Download model from HuggingFace repo or URL -// -// input (via model struct): -// - model.hf_repo: HF repo with optional tag, see common_download_split_repo_tag -// - model.hf_file: specific file in the repo (requires hf_repo) -// - model.url: simple download (used if hf_repo is empty) -// - model.path: local file path -// -// tag matching (for HF repos without model.hf_file): -// - if tag is specified, searches for GGUF matching that quantization -// - if no tag, searches for Q4_K_M, then Q4_0, then first available GGUF -// -// split GGUF: multi-part files like "model-00001-of-00003.gguf" are automatically -// detected and all parts are downloaded -// -// caching: -// - HF repos: uses HuggingFace cache -// - URLs: uses ETag-based caching -// -// when opts.offline=true, no network requests are made -// when download_mmproj=true, searches for mmproj in same directory as model or any parent directory -// then with the closest quantization bits -// when download_mtp=true, applies the same sibling search for an MTP-head GGUF -// -// returns result with model_path, mmproj_path and mtp_path (empty when not found / on failure) -common_download_model_result common_download_model( - const common_params_model & model, - const common_download_opts & opts = {} -); +// if url is a multi-part GGUF file, returns all parts, otherwise returns the single file +std::vector common_download_get_all_parts(const std::string & url); // returns list of cached models std::vector common_list_cached_models(); // download single file from url to local path // returns status code or -1 on error -// returns -2 if the download was skipped due to ETag mismatch (file outdated, skip_download=true) // skip_etag: if true, don't read/write .etag files (for HF cache where filename is the hash) int common_download_file_single(const std::string & url, const std::string & path, @@ -123,3 +100,12 @@ std::string common_docker_resolve_model(const std::string & docker); // - if tag is present, removes only files matching that tag (and orphaned blobs) // returns true if anything was removed bool common_download_remove(const std::string & hf_repo_with_tag); + +struct common_download_hf_plan { + hf_cache::hf_file primary; + hf_cache::hf_files model_files; + hf_cache::hf_file mmproj; + hf_cache::hf_file mtp; + hf_cache::hf_file preset; // if set, only this file is downloaded +}; +common_download_hf_plan common_download_get_hf_plan(const common_params_model & model, const common_download_opts & opts); diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index 1fae7884e1..f98cb0876f 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -686,59 +686,62 @@ value set_statement::execute_impl(context & ctx) { return mk_val(); } +static inline void bind_parameters(const std::string & name, const statements & this_args, const func_args & args, context & ctx) { + const size_t expected_count = this_args.size(); + const size_t input_count = args.count(); + + JJ_DEBUG("Invoking '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count); + for (size_t i = 0; i < expected_count; ++i) { + if (i < input_count) { + if (is_stmt(this_args[i])) { + // normal parameter + std::string param_name = cast_stmt(this_args[i])->val; + value param_value = args.get_kwarg_or_pos(param_name, i); + JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str()); + ctx.set_val(param_name, param_value); + } else if (is_stmt(this_args[i])) { + // default argument used as normal parameter + auto kwarg = cast_stmt(this_args[i]); + if (!is_stmt(kwarg->key)) { + throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'"); + } + std::string param_name = cast_stmt(kwarg->key)->val; + value param_value = args.get_kwarg_or_pos(param_name, i); + JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str()); + ctx.set_val(param_name, param_value); + } else { + throw std::runtime_error("Invalid parameter type in '" + name + "'"); + } + } else { + auto & default_arg = this_args[i]; + if (is_stmt(default_arg)) { + auto kwarg = cast_stmt(default_arg); + if (!is_stmt(kwarg->key)) { + throw std::runtime_error("Keyword argument key must be an identifier in '" + name + "'"); + } + std::string param_name = cast_stmt(kwarg->key)->val; + JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str()); + ctx.set_val(param_name, kwarg->val->execute(args.ctx)); + } else { + throw std::runtime_error("Not enough arguments provided to '" + name + "'"); + } + //std::string param_name = cast_stmt(default_args[i])->val; + //JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str()); + //ctx.var[param_name] = default_args[i]->execute(ctx); + } + } +} + value macro_statement::execute_impl(context & ctx) { if (!is_stmt(this->name)) { throw std::runtime_error("Macro name must be an identifier"); } std::string name = cast_stmt(this->name)->val; - const func_handler func = [this, name, &ctx](const func_args & args) -> value { - size_t expected_count = this->args.size(); - size_t input_count = args.count(); + const func_handler func = [this, name](const func_args & args) -> value { + context macro_ctx(args.ctx); // new scope for macro execution - JJ_DEBUG("Invoking macro '%s' with %zu input arguments (expected %zu)", name.c_str(), input_count, expected_count); - context macro_ctx(ctx); // new scope for macro execution - - // bind parameters - for (size_t i = 0; i < expected_count; ++i) { - if (i < input_count) { - if (is_stmt(this->args[i])) { - // normal parameter - std::string param_name = cast_stmt(this->args[i])->val; - value param_value = args.get_kwarg_or_pos(param_name, i); - JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str()); - macro_ctx.set_val(param_name, param_value); - } else if (is_stmt(this->args[i])) { - // default argument used as normal parameter - auto kwarg = cast_stmt(this->args[i]); - if (!is_stmt(kwarg->key)) { - throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'"); - } - std::string param_name = cast_stmt(kwarg->key)->val; - value param_value = args.get_kwarg_or_pos(param_name, i); - JJ_DEBUG(" Binding parameter '%s' to argument of type %s", param_name.c_str(), param_value->type().c_str()); - macro_ctx.set_val(param_name, param_value); - } else { - throw std::runtime_error("Invalid parameter type in macro '" + name + "'"); - } - } else { - auto & default_arg = this->args[i]; - if (is_stmt(default_arg)) { - auto kwarg = cast_stmt(default_arg); - if (!is_stmt(kwarg->key)) { - throw std::runtime_error("Keyword argument key must be an identifier in macro '" + name + "'"); - } - std::string param_name = cast_stmt(kwarg->key)->val; - JJ_DEBUG(" Binding parameter '%s' to default argument of type %s", param_name.c_str(), kwarg->val->type().c_str()); - macro_ctx.set_val(param_name, kwarg->val->execute(ctx)); - } else { - throw std::runtime_error("Not enough arguments provided to macro '" + name + "'"); - } - //std::string param_name = cast_stmt(default_args[i])->val; - //JJ_DEBUG(" Binding parameter '%s' to default", param_name.c_str()); - //macro_ctx.var[param_name] = default_args[i]->execute(ctx); - } - } + bind_parameters(name, this->args, args, macro_ctx); // execute macro body JJ_DEBUG("Executing macro '%s' body with %zu statements", name.c_str(), this->body.size()); @@ -752,6 +755,46 @@ value macro_statement::execute_impl(context & ctx) { return mk_val(); } +value call_statement::execute_impl(context & ctx) { + auto call_expr = cast_stmt(this->call); + if (!call_expr) { + throw std::runtime_error("Call statement requires a valid call expression"); + } + + value callee_val = call_expr->callee->execute(ctx); + if (!is_val(callee_val)) { + throw std::runtime_error("Callee is not a function: got " + callee_val->type()); + } + auto * callee_func = cast_val(callee_val); + + context caller_ctx(ctx); // new scope for caller execution + + const func_handler func = [this, caller_ctx = std::move(caller_ctx)](const func_args & args) -> value { + context block_ctx(caller_ctx); // new scope for block execution + + bind_parameters("caller", this->caller_args, args, block_ctx); + + JJ_DEBUG("Executing call body with %zu statements", this->body.size()); + auto res = exec_statements(this->body, block_ctx); + JJ_DEBUG("Call body execution complete, result: %s", res->val_str.str().c_str()); + return res; + }; + + context call_ctx(ctx); + call_ctx.set_val("caller", mk_val("caller", func)); + + func_args args(call_ctx); + + for (const auto & arg_expr : call_expr->args) { + auto arg_val = arg_expr->execute(ctx); + JJ_DEBUG(" Argument type: %s", arg_val->type().c_str()); + args.push_back(arg_val); + } + + JJ_DEBUG("Calling macro '%s' with %zu arguments", callee_func->name.c_str(), args.count()); + return callee_func->invoke(args); +} + value member_expression::execute_impl(context & ctx) { value object = this->object->execute(ctx); diff --git a/common/jinja/runtime.h b/common/jinja/runtime.h index b6f4a6ab48..37b4c35cac 100644 --- a/common/jinja/runtime.h +++ b/common/jinja/runtime.h @@ -552,6 +552,7 @@ struct call_statement : public statement { for (const auto & arg : this->caller_args) chk_type(arg); } std::string type() const override { return "CallStatement"; } + value execute_impl(context & ctx) override; }; struct ternary_expression : public expression { diff --git a/common/json-partial.cpp b/common/json-partial.cpp deleted file mode 100644 index aaf11310ab..0000000000 --- a/common/json-partial.cpp +++ /dev/null @@ -1,324 +0,0 @@ -#include "json-partial.h" - -#include "log.h" - -#include - -#include -#include - -using json = nlohmann::ordered_json; - -enum common_json_stack_element_type { - COMMON_JSON_STACK_ELEMENT_OBJECT, - COMMON_JSON_STACK_ELEMENT_KEY, - COMMON_JSON_STACK_ELEMENT_ARRAY, -}; - -struct common_json_stack_element { - common_json_stack_element_type type; - std::string key; -}; - -bool common_json_parse( - const std::string & input, - const std::string & healing_marker, - common_json & out) -{ - std::string::const_iterator it = input.begin(); - const auto end = input.end(); - return common_json_parse(it, end, healing_marker, out); -} - -bool common_json_parse( - std::string::const_iterator & it, - const std::string::const_iterator & end, - const std::string & healing_marker, - common_json & out) -{ - // // https://json.nlohmann.me/features/parsing/sax_interface/ - struct json_error_locator : public nlohmann::json_sax { - std::size_t position; - bool found_error; - std::string last_token; - std::string exception_message; - std::vector stack; - - json_error_locator() : position(0), found_error(false) {} - - bool parse_error(std::size_t position, const std::string & last_token, const json::exception & ex) override { // NOLINT - this->position = position - 1; - this->found_error = true; - this->last_token = last_token; - this->exception_message = ex.what(); - return false; - } - void close_value() { - if (!stack.empty() && (stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY)) { - stack.pop_back(); - } - } - bool null() override { // NOLINT - close_value(); - return true; - } - bool boolean(bool) override { // NOLINT - close_value(); - return true; - } - bool number_integer(number_integer_t) override { // NOLINT - close_value(); - return true; - } - bool number_unsigned(number_unsigned_t) override { // NOLINT - close_value(); - return true; - } - bool number_float(number_float_t, const string_t &) override { // NOLINT - close_value(); - return true; - } - bool string(string_t &) override { // NOLINT - close_value(); - return true; - } - bool binary(binary_t &) override { // NOLINT - close_value(); - return true; - } - bool start_object(std::size_t) override { // NOLINT - stack.push_back({COMMON_JSON_STACK_ELEMENT_OBJECT, ""}); - return true; - } - bool end_object() override { - GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT); - stack.pop_back(); - close_value(); - return true; - } - bool key(string_t & key) override { // NOLINT - stack.push_back({COMMON_JSON_STACK_ELEMENT_KEY, key}); - return true; - } - bool start_array(std::size_t) override { // NOLINT - stack.push_back({COMMON_JSON_STACK_ELEMENT_ARRAY, ""}); - return true; - } - bool end_array() override { - GGML_ASSERT(!stack.empty() && stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY); - stack.pop_back(); - close_value(); - return true; - } - }; - json_error_locator err_loc; - auto start = it; - json::sax_parse(it, end, &err_loc); - - if (err_loc.found_error) { - it = start; - auto temptative_end = it + err_loc.position; - // LOG_DBG("Error at position %zu (is_end = %s): %s\n", err_loc.position, temptative_end == end ? "true" : "false", err_loc.exception_message.c_str()); - - auto input = std::string(it, temptative_end); - try { - out.json = json::parse(input); - // out.json = json::parse(it, temptative_end); - it = temptative_end; - return true; - } catch (const std::exception & ex) { - // No, needs healing. - LOG_DBG("Failed to parse up to error: %s: <<<%s>>>\n", ex.what(), std::string(it, temptative_end).c_str()); - } - auto can_parse = [](const std::string & str) { - try { - auto _ = json::parse(str); // NOLINT - return true; - } catch (const std::exception &) { - return false; - } - }; - if (!healing_marker.empty() && !err_loc.stack.empty()) { - std::string str(it, temptative_end); - auto last_non_sp_pos = str.find_last_not_of(" \n\r\t"); - if (last_non_sp_pos == std::string::npos) { - throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); - } - auto last_non_sp_char = str[last_non_sp_pos]; - // Used to detect stops on a number, which may not be complete. - auto was_maybe_number = [&]() { - if (!str.empty() && std::isspace(str.back())) { - return false; - } - return std::isdigit(last_non_sp_char) || - last_non_sp_char == '.' || - last_non_sp_char == 'e' || - last_non_sp_char == 'E' || - last_non_sp_char == '-'; - }; - - std::string closing; - for (size_t i = err_loc.stack.size(); i > 0; i--) { - auto & el = err_loc.stack[i - 1]; - if (el.type == COMMON_JSON_STACK_ELEMENT_OBJECT) { - closing += "}"; - } else if (el.type == COMMON_JSON_STACK_ELEMENT_ARRAY) { - closing += "]"; - } else if (el.type != COMMON_JSON_STACK_ELEMENT_KEY) { - throw std::runtime_error("Unexpected stack element type"); - } - } - - // Matches a potentially partial unicode escape sequence, e.g. \u, \uX, \uXX, \uXXX, \uXXXX - static const std::regex partial_unicode_regex(R"(\\u(?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F](?:[0-9a-fA-F])?)?)?)?$)"); - - auto is_high_surrogate = [&](const std::string & s) { - // Check if a partial of a high surrogate (U+D800-U+DBFF) - return s.length() >= 4 && - s[0] == '\\' && s[1] == 'u' && - std::tolower(s[2]) == 'd' && - (s[3] == '8' || s[3] == '9' || std::tolower(s[3]) == 'a' || std::tolower(s[3]) == 'b'); - }; - - // Initialize the unicode marker to a low surrogate to handle the edge case - // where a high surrogate (U+D800-U+DBFF) is immediately followed by a - // backslash (\) - std::string unicode_marker_padding = "udc00"; - std::smatch last_unicode_seq; - - if (std::regex_search(str, last_unicode_seq, partial_unicode_regex)) { - std::smatch second_last_seq; - std::string prelude = str.substr(0, last_unicode_seq.position()); - - // Pad the escape sequence with 0s until it forms a complete sequence of 6 characters - unicode_marker_padding = std::string(6 - last_unicode_seq.length(), '0'); - - if (is_high_surrogate(last_unicode_seq.str())) { - // If the sequence is a partial match for a high surrogate, add a low surrogate (U+DC00-U+UDFF) - unicode_marker_padding += "\\udc00"; - } else if (std::regex_search(prelude, second_last_seq, partial_unicode_regex)) { - if (is_high_surrogate(second_last_seq.str())) { - // If this follows a high surrogate, pad it to be a low surrogate - if (last_unicode_seq.length() == 2) { - unicode_marker_padding = "dc00"; - } else if (last_unicode_seq.length() == 3) { - unicode_marker_padding = "c00"; - } else { - // The original unicode_marker_padding is already padded with 0s - } - } - } - } - - const auto & magic_seed = out.healing_marker.marker = healing_marker;//"$llama.cpp.json$"; - - if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_KEY) { - // We're inside an object value - if (last_non_sp_char == ':' && can_parse(str + "1" + closing)) { - // Was about to create an object value - str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; - } else if (can_parse(str + ": 1" + closing)) { - str += (out.healing_marker.json_dump_marker = ":\"" + magic_seed) + "\"" + closing; - } else if (last_non_sp_char == '{' && can_parse(str + closing)) { - // Was about to create an object - str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; - } else if (can_parse(str + "\"" + closing)) { - // Was inside an object value string - str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; - } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { - // Was inside an object value string after an escape - str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; - } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) { - // Was inside an object value string after a partial unicode escape - str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing; - } else { - // find last : - auto last_pos = str.find_last_of(':'); - if (last_pos == std::string::npos) { - throw std::runtime_error("Cannot heal a truncated JSON that stopped in an unknown location"); - } - // Cutting back to opening : for object value - str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; - } - } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_ARRAY) { - if ((last_non_sp_char == ',' || last_non_sp_char == '[') && can_parse(str + "1" + closing)) { - // Was about to create an array value - str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; - } else if (can_parse(str + "\"" + closing)) { - // Was inside an array value string - str += (out.healing_marker.json_dump_marker = magic_seed) + "\"" + closing; - } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"" + closing)) { - // Was inside an array value string after an escape - str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\"" + closing; - } else if (can_parse(str + unicode_marker_padding + "\"" + closing)) { - // Was inside an array value string after a partial unicode escape - str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\"" + closing; - } else if (!was_maybe_number() && can_parse(str + ", 1" + closing)) { - // Had just finished a value - str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\"" + closing; - } else { - auto last_pos = str.find_last_of("[,"); - if (last_pos == std::string::npos) { - throw std::runtime_error("Cannot heal a truncated JSON array stopped in an unknown location"); - } - // Cutting back to last [ or , for array value - str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; - } - } else if (err_loc.stack.back().type == COMMON_JSON_STACK_ELEMENT_OBJECT) { - if ((last_non_sp_char == '{' && can_parse(str + closing)) || - (last_non_sp_char == ',' && can_parse(str + "\"\": 1" + closing))) { - // Was about to create an object key+value - str += (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\": 1" + closing; - } else if (!was_maybe_number() && can_parse(str + ",\"\": 1" + closing)) { - // Was about to create an object key+value - str += (out.healing_marker.json_dump_marker = ",\"" + magic_seed) + "\": 1" + closing; - } else if (can_parse(str + "\": 1" + closing)) { - // Was inside an object key string - str += (out.healing_marker.json_dump_marker = magic_seed) + "\": 1" + closing; - } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\": 1" + closing)) { - // Was inside an object key string after an escape - str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\": 1" + closing; - } else if (can_parse(str + unicode_marker_padding + "\": 1" + closing)) { - // Was inside an object key string after a partial unicode escape - str += (out.healing_marker.json_dump_marker = unicode_marker_padding + magic_seed) + "\": 1" + closing; - } else { - auto last_pos = str.find_last_of(':'); - if (last_pos == std::string::npos) { - throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); - } - // fprintf(stderr, "Cutting back to last : for object key+value\n"); - str = str.substr(0, last_pos + 1) + (out.healing_marker.json_dump_marker = "\"" + magic_seed) + "\"" + closing; - } - } else { - throw std::runtime_error("Cannot heal a truncated JSON object stopped in an unknown location"); - } - // fprintf(stderr, "HEALED:\nSTRING <<<\n%s\n>>>\n\nmagic_cut: <<<\n%s\n>>>\n\n", str.c_str(), out.healing_marker.json_dump_marker.c_str()); - out.json = json::parse(str); - it = temptative_end; - return true; - } - // handle unclosed top-level primitive - if (err_loc.position != 0 && !healing_marker.empty() && err_loc.stack.empty()) { - std::string str(it, temptative_end); - const auto & magic_seed = out.healing_marker.marker = healing_marker; - if (can_parse(str + "\"")) { - // Was inside an string - str += (out.healing_marker.json_dump_marker = magic_seed) + "\""; - } else if (str[str.length() - 1] == '\\' && can_parse(str + "\\\"")) { - // Was inside an string after an escape - str += (out.healing_marker.json_dump_marker = "\\" + magic_seed) + "\""; - } else { - // TODO: handle more unclosed top-level primitive if the stack was empty but we got an error (e.g. "tru", "\"", etc...) - // fprintf(stderr, "Closing: TODO\n"); - return false; - } - out.json = json::parse(str); - it = temptative_end; - return true; - } - return false; - } - out.json = json::parse(it, end); - it = end; - return true; -} diff --git a/common/json-partial.h b/common/json-partial.h deleted file mode 100644 index be51aabfbf..0000000000 --- a/common/json-partial.h +++ /dev/null @@ -1,39 +0,0 @@ -#pragma once - -// TODO: use json_fwd.hpp when possible -#include - -// Healing marker (empty if the JSON was fully parsed / wasn't healed). -struct common_healing_marker { - // Raw marker. - std::string marker; - - // Cutting the `common_json.json.dump()` string at the (only) occurrence of this marker should yield the original partial JSON string (modulo spaces / if it had the same dump format). - std::string json_dump_marker; -}; - -// Represents a parsed JSON object, with its optional healing marker (a JSON dump fragment that can be used to find the position of healing in the JSON dump string) -struct common_json { - nlohmann::ordered_json json; - - common_healing_marker healing_marker; -}; - -// Parse the JSON string, healing (closing) any partial JSON if `healing_marker` is not empty. -// -// Healing completes partial JSON strings by adding a (possibly modified) healing marker, then whatever is needed to close the JSON. -// This allows to parse the resulting healed JSON string, yet be able to cut it again if needed at the healing marker. -// (this is used when parsing JSON outputs from the models, then crafting partial JSONs for the partial tool calls in OAI format). -// -// For instance, parsing `{` with a healing marker `foo` will produce a healed JSON `{"foo":1}`, w/ json_dump_marker = `"foo"` (which can be used to break the JSON again). -bool common_json_parse( - const std::string & input, - const std::string & healing_marker, - common_json & out); - -// Parse the JSON string (see overload above), but advancing an iterator to the end of the input when the (potentially partial) parsing succeeds. -bool common_json_parse( - std::string::const_iterator & it, - const std::string::const_iterator & end, - const std::string & healing_marker, - common_json & out); diff --git a/common/json-schema-to-grammar.cpp b/common/json-schema-to-grammar.cpp index e2c4d6ce22..b18607cd65 100644 --- a/common/json-schema-to-grammar.cpp +++ b/common/json-schema-to-grammar.cpp @@ -233,27 +233,27 @@ struct BuiltinRule { }; static std::unordered_map PRIMITIVE_RULES = { - {"boolean", {"(\"true\" | \"false\") space", {}}}, + {"boolean", {"(\"true\" | \"false\")", {}}}, {"decimal-part", {"[0-9]{1,16}", {}}}, {"integral-part", {"[0] | [1-9] [0-9]{0,15}", {}}}, - {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)? space", {"integral-part", "decimal-part"}}}, - {"integer", {"(\"-\"? integral-part) space", {"integral-part"}}}, + {"number", {"(\"-\"? integral-part) (\".\" decimal-part)? ([eE] [-+]? integral-part)?", {"integral-part", "decimal-part"}}}, + {"integer", {"(\"-\"? integral-part)", {"integral-part"}}}, {"value", {"object | array | string | number | boolean | null", {"object", "array", "string", "number", "boolean", "null"}}}, - {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? \"}\" space", {"string", "value"}}}, - {"array", {"\"[\" space ( value (\",\" space value)* )? \"]\" space", {"value"}}}, - {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\" space", {}}}, + {"object", {"\"{\" space ( string \":\" space value (\",\" space string \":\" space value)* )? space \"}\"", {"string", "value"}}}, + {"array", {"\"[\" space ( value (\",\" space value)* )? space \"]\"", {"value"}}}, + {"uuid", {"\"\\\"\" [0-9a-fA-F]{8} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{4} \"-\" [0-9a-fA-F]{12} \"\\\"\"", {}}}, {"char", {"[^\"\\\\\\x7F\\x00-\\x1F] | [\\\\] ([\"\\\\bfnrt] | \"u\" [0-9a-fA-F]{4})", {}}}, - {"string", {"\"\\\"\" char* \"\\\"\" space", {"char"}}}, - {"null", {"\"null\" space", {}}}, + {"string", {"\"\\\"\" char* \"\\\"\"", {"char"}}}, + {"null", {"\"null\"", {}}}, }; static std::unordered_map STRING_FORMAT_RULES = { {"date", {"[0-9]{4} \"-\" ( \"0\" [1-9] | \"1\" [0-2] ) \"-\" ( \"0\" [1-9] | [1-2] [0-9] | \"3\" [0-1] )", {}}}, {"time", {"([01] [0-9] | \"2\" [0-3]) \":\" [0-5] [0-9] \":\" [0-5] [0-9] ( \".\" [0-9]{3} )? ( \"Z\" | ( \"+\" | \"-\" ) ( [01] [0-9] | \"2\" [0-3] ) \":\" [0-5] [0-9] )", {}}}, {"date-time", {"date \"T\" time", {"date", "time"}}}, - {"date-string", {"\"\\\"\" date \"\\\"\" space", {"date"}}}, - {"time-string", {"\"\\\"\" time \"\\\"\" space", {"time"}}}, - {"date-time-string", {"\"\\\"\" date-time \"\\\"\" space", {"date-time"}}} + {"date-string", {"\"\\\"\" date \"\\\"\"", {"date"}}}, + {"time-string", {"\"\\\"\" time \"\\\"\"", {"time"}}}, + {"date-time-string", {"\"\\\"\" date-time \"\\\"\"", {"date-time"}}} }; static bool is_reserved_name(const std::string & name) { @@ -551,16 +551,16 @@ private: } return join_seq(); }; - return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space"); + return _add_rule(name, "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\""); } /* Returns a rule that matches a JSON string that is none of the provided strings not_strings({"a"}) - -> ["] ( [a] char+ | [^"a] char* )? ["] space + -> ["] ( [a] char+ | [^"a] char* )? ["] not_strings({"and", "also"}) - -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] space + -> ["] ( [a] ([l] ([s] ([o] char+ | [^"o] char*) | [^"s] char*) | [n] ([d] char+ | [^"d] char*) | [^"ln] char*) | [^"a] char* )? ["] */ std::string _not_strings(const std::vector & strings) { @@ -619,7 +619,7 @@ private: if (!trie.is_end_of_string) { out << "?"; } - out << " [\"] space"; + out << " [\"]"; return out.str(); } @@ -725,7 +725,7 @@ private: rule += " )?"; } - rule += " \"}\" space"; + rule += " space \"}\""; return rule; } @@ -858,14 +858,14 @@ public: return _add_rule(rule_name, _generate_union_rule(name, schema_types)); } if (schema.contains("const")) { - return _add_rule(rule_name, _generate_constant_rule(schema["const"]) + " space"); + return _add_rule(rule_name, _generate_constant_rule(schema["const"])); } if (schema.contains("enum")) { std::vector enum_values; for (const auto & v : schema["enum"]) { enum_values.push_back(_generate_constant_rule(v)); } - return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_values, " | ") + ")"); } if ((schema_type.is_null() || schema_type == "object") && (schema.contains("properties") || @@ -933,7 +933,7 @@ public: } } if (!enum_intersection.empty()) { - return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space"); + return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ")"); } } return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); @@ -948,7 +948,7 @@ public: } rule += visit(items[i], name + (name.empty() ? "" : "-") + "tuple-" + std::to_string(i)); } - rule += " \"]\" space"; + rule += " space \"]\""; return _add_rule(rule_name, rule); } std::string item_rule_name = visit(items, name + (name.empty() ? "" : "-") + "item"); @@ -956,7 +956,7 @@ public: json max_items_json = schema.contains("maxItems") ? schema["maxItems"] : json(); int max_items = max_items_json.is_number_integer() ? max_items_json.get() : std::numeric_limits::max(); - return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " \"]\" space"); + return _add_rule(rule_name, "\"[\" space " + build_repetition(item_rule_name, min_items, max_items, "\",\" space") + " space \"]\""); } if ((schema_type.is_null() || schema_type == "string") && schema.contains("pattern")) { return _visit_pattern(schema["pattern"], rule_name); @@ -972,7 +972,7 @@ public: std::string char_rule = _add_primitive("char", PRIMITIVE_RULES.at("char")); int min_len = schema.contains("minLength") ? schema["minLength"].get() : 0; int max_len = schema.contains("maxLength") ? schema["maxLength"].get() : std::numeric_limits::max(); - return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\" space"); + return _add_rule(rule_name, "\"\\\"\" " + build_repetition(char_rule, min_len, max_len) + " \"\\\"\""); } if (schema_type == "integer" && (schema.contains("minimum") || schema.contains("exclusiveMinimum") || schema.contains("maximum") || schema.contains("exclusiveMaximum"))) { int64_t min_value = std::numeric_limits::min(); @@ -990,7 +990,7 @@ public: std::stringstream out; out << "("; build_min_max_int(min_value, max_value, out); - out << ") space"; + out << ")"; return _add_rule(rule_name, out.str()); } if (schema.empty() || schema_type == "object") { diff --git a/common/peg-parser.cpp b/common/peg-parser.cpp index d4b491a80e..807e952d90 100644 --- a/common/peg-parser.cpp +++ b/common/peg-parser.cpp @@ -6,13 +6,14 @@ #include "unicode.h" #include +#include #include #include #include #include #include +#include #include -#include // Trick to catch missing branches template @@ -88,40 +89,7 @@ struct trie { return match_result{match_result::NO_MATCH}; } - struct prefix_and_next { - std::vector prefix; - std::vector next_chars; - }; - - std::vector collect_prefix_and_next() { - std::vector prefix; - std::vector result; - collect_prefix_and_next(0, prefix, result); - return result; - } - private: - void collect_prefix_and_next(size_t index, std::vector & prefix, std::vector & out) { - if (!nodes[index].is_word) { - if (!nodes[index].children.empty()) { - std::vector chars; - chars.reserve(nodes[index].children.size()); - for (const auto & p : nodes[index].children) { - chars.push_back(p.first); - } - out.emplace_back(prefix_and_next{prefix, chars}); - } - } - - for (const auto & p : nodes[index].children) { - uint32_t ch = p.first; - auto child = p.second; - prefix.push_back(ch); - collect_prefix_and_next(child, prefix, out); - prefix.pop_back(); - } - } - size_t create_node() { size_t index = nodes.size(); nodes.emplace_back(); @@ -153,6 +121,65 @@ struct trie { } }; +// Aho-Corasick automaton +struct aho_corasick { + trie t; + std::vector fail; // failure links + std::vector order; // states in BFS order + std::vector terminal; // match states (directly or via a suffix link) + std::set alphabet; // every character with a transition + + aho_corasick(const std::vector & strings) : t(strings) { + const auto & nodes = t.nodes; + const size_t n = nodes.size(); + + fail.assign(n, 0); + order.reserve(n); + + std::deque queue{ 0 }; + while (!queue.empty()) { + size_t u = queue.front(); + queue.pop_front(); + order.push_back(u); + for (const auto & [ch, v] : nodes[u].children) { + if (u != 0) { + size_t f = fail[u]; + while (f && nodes[f].children.find(ch) == nodes[f].children.end()) { + f = fail[f]; + } + auto it = nodes[f].children.find(ch); + fail[v] = (it != nodes[f].children.end() && it->second != v) ? it->second : 0; + } + queue.push_back(v); + } + } + + terminal.assign(n, false); + for (size_t u : order) { + terminal[u] = nodes[u].is_word || (u != 0 && terminal[fail[u]]); + } + + for (const auto & node : nodes) { + for (const auto & [ch, v] : node.children) { + alphabet.insert(ch); + } + } + } + + size_t num_states() const { return t.nodes.size(); } + bool is_terminal(size_t s) const { return terminal[s]; } + + // follow failure links until a transition on `ch` exists. + size_t next(size_t state, uint32_t ch) const { + const auto & nodes = t.nodes; + while (state && nodes[state].children.find(ch) == nodes[state].children.end()) { + state = fail[state]; + } + auto it = nodes[state].children.find(ch); + return it != nodes[state].children.end() ? it->second : 0; + } +}; + static std::pair parse_hex_escape(const std::string & str, size_t pos, int hex_count) { if (pos + hex_count > str.length()) { return {0, 0}; @@ -894,6 +921,10 @@ struct parser_executor { common_peg_parse_result operator()(const common_peg_gbnf_parser & p) { return arena.parse(p.child, ctx, start_pos); } + + common_peg_parse_result operator()(const common_peg_ac_parser & p) { + return arena.parse(p.child, ctx, start_pos); + } }; common_peg_parse_result common_peg_arena::parse(common_peg_parse_context & ctx, size_t start) const { @@ -962,7 +993,8 @@ void common_peg_arena::resolve_refs() { std::is_same_v || std::is_same_v || std::is_same_v || - std::is_same_v) { + std::is_same_v || + std::is_same_v) { p.child = resolve_ref(p.child); } else if constexpr (std::is_same_v) { p.child = resolve_ref(p.child); @@ -992,12 +1024,12 @@ void common_peg_arena::resolve_refs() { } std::string common_peg_arena::dump(common_peg_parser_id id) const { - std::unordered_set visited; + std::set visited; return dump_impl(id, visited); } std::string common_peg_arena::dump_impl(common_peg_parser_id id, - std::unordered_set & visited) const { + std::set & visited) const { // Check for cycles if (visited.count(id)) { return "[cycle]"; @@ -1043,6 +1075,8 @@ std::string common_peg_arena::dump_impl(common_peg_parser_id return "Atomic(" + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Gbnf(" + p.grammar + ", " + dump_impl(p.child, visited) + ")"; + } else if constexpr (std::is_same_v) { + return "Ac(" + string_join(p.delimiters, " | ") + ", " + dump_impl(p.child, visited) + ")"; } else if constexpr (std::is_same_v) { return "Any"; } else if constexpr (std::is_same_v) { @@ -1342,7 +1376,7 @@ common_peg_parser common_peg_parser_builder::json_object() { common_peg_parser common_peg_parser_builder::json_array() { return rule("json-array", [this]() { auto ws = space(); - auto elements = sequence({json(), zero_or_more(sequence({literal(","), ws, json()}))}); + auto elements = sequence({json(), zero_or_more(sequence({ws, literal(","), ws, json()}))}); return sequence({ literal("["), ws, @@ -1452,6 +1486,13 @@ common_peg_parser common_peg_parser_builder::json_member(const std::string & key }); } +common_peg_parser common_peg_parser_builder::ac(const common_peg_parser & p, const std::vector & delimiters) { + if (delimiters.empty()) { + throw std::runtime_error("ac parser requires at least one delimiter"); + } + return add(common_peg_ac_parser{p, delimiters}); +} + static std::string gbnf_escape_char_class(uint32_t c) { if (c == '-' || c == ']' || c == '[' || c == '\\') { return "\\" + std::string(1, (char) c); @@ -1502,61 +1543,118 @@ static std::string gbnf_escape_char_class(uint32_t c) { return std::string(buf); } -static std::string gbnf_excluding_pattern(const std::vector & strings) { - trie matcher(strings); - auto pieces = matcher.collect_prefix_and_next(); - - std::string pattern; - std::string trailing; // optional proper-prefix of a delimiter, allowed only at the very end - for (size_t i = 0; i < pieces.size(); ++i) { - if (i > 0) { - pattern += " | "; - } - - const auto & pre = pieces[i].prefix; - const auto & chars = pieces[i].next_chars; - - std::string cls; - cls.reserve(chars.size()); - for (uint32_t ch : chars) { - cls += gbnf_escape_char_class(ch); - } - - if (!pre.empty()) { - std::string pre_literal = gbnf_format_literal(common_unicode_cpts_to_utf8(pre)); - pattern += pre_literal + " [^" + cls + "]"; - // Each interior alternative consumes a delimiter-prefix plus a disambiguating - // char, so the repetition alone cannot match a value that *ends* on a proper - // prefix of a delimiter (e.g. a trailing "\n" when the delimiter is - // "\n\n"). The runtime until() (greedy first-match) accepts such - // values, so without this the grammar would reject input the parser accepts. - // Allow the value to terminate on any proper prefix as an optional tail. - // This makes the grammar a slight superset of the runtime language (a value - // may end on the longest prefix, which greedy first-match would not itself - // produce); harmless for constrained generation, which only needs to admit - // every runtime-valid string. - if (!trailing.empty()) { - trailing += " | "; - } - trailing += pre_literal; - } else { - pattern += "[^" + cls + "]"; - } +static std::string gbnf_char_class(const std::vector & chars, bool negate) { + std::string s = negate ? "[^" : "["; + for (uint32_t ch : chars) { + s += gbnf_escape_char_class(ch); } - - std::string result = "(" + pattern + ")*"; - if (!trailing.empty()) { - result += " (" + trailing + ")?"; - } - return result; + return s + "]"; } -static std::unordered_set collect_reachable_rules( +static std::string gbnf_ac_grammar( + const common_grammar_builder & builder, + const std::string & prefix, + const std::vector & strings, + const std::function &, + const std::map> &, + const std::vector &, + const std::function &)> & build_rule) { + aho_corasick ac(strings); + + auto state_name = [&](size_t s) -> std::string { + if (s == 0) { + return prefix; + } + std::string num = std::to_string(s); + num = num.size() == 1 ? ("0" + num) : num; + return prefix + "-" + num; + }; + + for (size_t q = 0; q < ac.num_states(); q++) { + if (ac.is_terminal(q)) { + continue; // match states + } + + std::map> buckets; + std::vector completing; // chars that complete a delimiter + std::vector specific; // chars with an explicit transition + for (uint32_t c : ac.alphabet) { + size_t d = ac.next(q, c); + if (ac.is_terminal(d)) { + completing.push_back(c); + specific.push_back(c); + } else if (d != 0) { + buckets[d].push_back(c); // specific non-root destination + specific.push_back(c); + } + } + + builder.add_rule(state_name(q), build_rule(completing, buckets, specific, state_name)); + } + + // An empty delimiter makes the start state terminal. Emit an entry rule + // that matches the empty string so the returned reference stays valid. + if (ac.is_terminal(0)) { + builder.add_rule(prefix, "|"); + } + + return state_name(0); +} + +// GBNF grammar matching strings that contain no string in `strings` as a +// substring. Emits the complement of an Aho-Corasick automaton DFA and returns +// the start state rule name. +// +// ref: https://github.com/ggml-org/llama.cpp/pull/24839 +static std::string gbnf_excluding_grammar(const common_grammar_builder & builder, + const std::string & prefix, + const std::vector & strings) { + return gbnf_ac_grammar(builder, prefix, strings, + [](const std::vector & /*completing*/, + const std::map> & buckets, + const std::vector & specific, + const std::function & state_name) { + // every state is accepting and completing chars get no + // alternative, so a forbidden string can never be matched + std::string rhs = "|"; + for (const auto & [d, chars] : buckets) { + rhs += " " + gbnf_char_class(chars, false) + " " + state_name(d) + " |"; + } + rhs += " " + gbnf_char_class(specific, true) + " " + state_name(0); + return rhs; + }); +} + +// GBNF grammar matching everything up to and including the first occurrence of +// any string in `strings`. Emits the Aho-Corasick automaton DFA and returns +// the start state rule name. +static std::string gbnf_including_grammar(const common_grammar_builder & builder, + const std::string & prefix, + const std::vector & strings) { + return gbnf_ac_grammar(builder, prefix, strings, + [](const std::vector & completing, + const std::map> & buckets, + const std::vector & specific, + const std::function & state_name) { + std::vector alts; + if (!completing.empty()) { + alts.push_back(gbnf_char_class(completing, false)); // terminate on match + } + for (const auto & [d, chars] : buckets) { + alts.push_back(gbnf_char_class(chars, false) + " " + state_name(d)); + } + // every other character keeps scanning from the start state + alts.push_back(gbnf_char_class(specific, true) + " " + state_name(0)); + return string_join(alts, " | "); + }); +} + +static std::set collect_reachable_rules( const common_peg_arena & arena, const common_peg_parser_id & rule ) { - std::unordered_set reachable; - std::unordered_set visited; + std::set reachable; + std::set visited; std::function visit = [&](common_peg_parser_id id) { const auto & parser = arena.get(id); @@ -1588,6 +1686,7 @@ static std::unordered_set collect_reachable_rules( std::is_same_v || std::is_same_v || std::is_same_v || + std::is_same_v || std::is_same_v) { visit(p.child); } else if constexpr (std::is_same_v) { @@ -1765,7 +1864,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo if (p.delimiters.empty()) { return ".*"; } - return gbnf_excluding_pattern(p.delimiters); + return gbnf_excluding_grammar(builder, "until-" + std::to_string(id), p.delimiters); } else if constexpr (std::is_same_v) { if (schema_delegates(p)) { return to_gbnf(p.child); @@ -1782,6 +1881,8 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo return to_gbnf(p.child); } else if constexpr (std::is_same_v) { return p.grammar; + } else if constexpr (std::is_same_v) { + return gbnf_including_grammar(builder, "ac-" + std::to_string(id), p.delimiters); } else { static_assert(is_always_false_v); } @@ -1789,7 +1890,7 @@ void common_peg_arena::build_grammar(const common_grammar_builder & builder, boo }; // Collect reachable rules - std::unordered_set reachable_rules; + std::set reachable_rules; if (lazy) { // Collect rules reachable from trigger rules @@ -1918,6 +2019,8 @@ static nlohmann::json serialize_parser_variant(const common_peg_parser_variant & }; } else if constexpr (std::is_same_v) { return json{{"type", "gbnf"}, {"child", p.child}, {"grammar", p.grammar}}; + } else if constexpr (std::is_same_v) { + return json{{"type", "ac"}, {"child", p.child}, {"delimiters", p.delimiters}}; } }, variant); } @@ -2090,6 +2193,16 @@ static common_peg_parser_variant deserialize_parser_variant(const nlohmann::json }; } + if (type == "ac") { + if (!j.contains("child") || !j.contains("delimiters") || !j["delimiters"].is_array() || j["delimiters"].empty()) { + throw std::runtime_error("ac parser requires 'child' and a non-empty 'delimiters' array"); + } + return common_peg_ac_parser{ + j["child"].get(), + j["delimiters"].get>(), + }; + } + throw std::runtime_error("Unknown parser type: " + type); } diff --git a/common/peg-parser.h b/common/peg-parser.h index b6bb05214b..c198499dd9 100644 --- a/common/peg-parser.h +++ b/common/peg-parser.h @@ -3,8 +3,8 @@ #include #include +#include #include -#include #include #include #include @@ -275,6 +275,11 @@ struct common_peg_gbnf_parser { std::string grammar; }; +struct common_peg_ac_parser { + common_peg_parser_id child; + std::vector delimiters; +}; + // Variant holding all parser types using common_peg_parser_variant = std::variant< common_peg_epsilon_parser, @@ -296,7 +301,8 @@ using common_peg_parser_variant = std::variant< common_peg_ref_parser, common_peg_atomic_parser, common_peg_tag_parser, - common_peg_gbnf_parser + common_peg_gbnf_parser, + common_peg_ac_parser >; class common_peg_arena { @@ -335,7 +341,7 @@ class common_peg_arena { friend class common_peg_parser_builder; private: - std::string dump_impl(common_peg_parser_id id, std::unordered_set & visited) const; + std::string dump_impl(common_peg_parser_id id, std::set & visited) const; common_peg_parser_id add_parser(common_peg_parser_variant parser); void add_rule(const std::string & name, common_peg_parser_id id); @@ -514,6 +520,13 @@ class common_peg_parser_builder { // the child's grammar. Parsing delegates entirely to the child. common_peg_parser gbnf(const common_peg_parser & p, const std::string & grammar) { return add(common_peg_gbnf_parser{p, grammar}); } + // Wraps a child parser but emits a GBNF grammar built from the Aho-Corasick + // automaton of `delimiters`, matching everything up to and including the + // first delimiter. Parsing delegates entirely to the child, which is + // responsible for consuming the delimiter (e.g. until(D) + literal(D)). + common_peg_parser ac(const common_peg_parser & p, const std::vector & delimiters); + common_peg_parser ac(const common_peg_parser & p, const std::string & delimiter) { return ac(p, std::vector{delimiter}); } + void set_root(const common_peg_parser & p); common_peg_arena build(); diff --git a/common/speculative.cpp b/common/speculative.cpp index 6f387f2cfc..c922a3f592 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -161,6 +161,10 @@ struct common_speculative_impl { virtual void accept(llama_seq_id seq_id, uint16_t n_accepted, bool is_other) = 0; + // (optional) serialize/restore per-seq internal state (e.g. eagle3's deferred boundary). + virtual bool get_state(llama_seq_id /*seq_id*/, std::vector & /*data*/) const { return false; } + virtual void set_state(llama_seq_id /*seq_id*/, const std::vector & /*data*/) {} + // true if this implementation requires the target context to extract post-norm embeddings virtual bool need_embd() const = 0; @@ -841,6 +845,49 @@ struct common_speculative_impl_draft_eagle3 : public common_speculative_impl { (size_t) n_embd_dec * sizeof(float)); } + // we only need to stash the deferred boundary's g_embd row for recurrent/hybrid targets: + // their single-position checkpoints drop it on restore + bool need_boundary_stash() const { + const llama_model * model_tgt = llama_get_model(params.ctx_tgt); + return llama_model_is_recurrent(model_tgt) || llama_model_is_hybrid(model_tgt); + } + + bool get_state(llama_seq_id seq_id, std::vector & data) const override { + if (!need_boundary_stash()) { + return false; + } + if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq || pending_pos_last[seq_id] < 0) { + return false; + } + + const llama_pos pos = pending_pos_last[seq_id]; + const std::vector & g = pending_g_last[seq_id]; + + data.resize(sizeof(llama_pos) + g.size() * sizeof(float)); + std::memcpy(data.data(), &pos, sizeof(llama_pos)); + std::memcpy(data.data() + sizeof(llama_pos), g.data(), g.size() * sizeof(float)); + return true; + } + + void set_state(llama_seq_id seq_id, const std::vector & data) override { + if (!need_boundary_stash()) { + return; + } + if (seq_id < 0 || seq_id >= (llama_seq_id) n_seq) { + return; + } + if (data.size() != sizeof(llama_pos) + (size_t) n_embd_dec * sizeof(float)) { + return; + } + + llama_pos pos = -1; + std::memcpy(&pos, data.data(), sizeof(llama_pos)); + + pending_pos_last[seq_id] = pos; + pending_g_last[seq_id].resize(n_embd_dec); + std::memcpy(pending_g_last[seq_id].data(), data.data() + sizeof(llama_pos), (size_t) n_embd_dec * sizeof(float)); + } + bool need_embd() const override { return false; } @@ -858,7 +905,13 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { int32_t n_embd = 0; - bool is_mem_shared = false; + // One MTP draft driver, three modes (set once in the ctor): + // is_mem_shared (gemma4): shares the target KV, runs all heads in one graph. + // chain_heads (step35): n_mtp_layers trained heads, one per draft step. + // neither (qwen35 / qwen35moe): a single trained MTP head. + int32_t n_mtp_layers = 1; + bool is_mem_shared = false; // gemma4 + bool chain_heads = false; // derived in the ctor: n_mtp_layers > 1 && !is_mem_shared // Per-sequence cross-batch carryover: pair (h_p, x_{p+1}) at MTP pos p+1. // The last h-row of one process() call needs the first token of the NEXT @@ -873,10 +926,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { std::vector> verify_h; std::vector verify_h_rows; - // Per-seq draft length from the last draft() call, used in accept() to - // roll back ctx_dft's recurrent state past the AR draft's redundant - // pre-advancement before process() mirrored the verify batch. - std::vector last_n_drafted; + std::vector i_last; + std::vector> chain_h; common_speculative_impl_draft_mtp(const common_params_speculative & params, uint32_t n_seq) : common_speculative_impl(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, n_seq) @@ -889,6 +940,7 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { n_embd = llama_model_n_embd_out(llama_get_model(ctx_dft)); GGML_ASSERT(n_embd == llama_model_n_embd(llama_get_model(ctx_tgt)) && "MTP input row width must match the target h_nextn width"); + n_mtp_layers = std::max(1, (int) llama_model_n_layer_nextn(llama_get_model(ctx_dft))); LOG_INF("%s: adding speculative implementation 'draft-mtp'\n", __func__); LOG_INF("%s: - n_max=%d, n_min=%d, p_min=%.2f, n_embd=%d, backend_sampling=%d\n", __func__, this->params.n_max, this->params.n_min, this->params.p_min, n_embd, (int) this->params.backend_sampling); @@ -935,16 +987,25 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { llama_set_embeddings_nextn(ctx_dft, true, /*masked*/ true); is_mem_shared = llama_get_ctx_other(ctx_dft) == ctx_tgt; + chain_heads = n_mtp_layers > 1 && !is_mem_shared; + + if (chain_heads) { + this->params.n_max = std::min(this->params.n_max, n_mtp_layers); + + chain_h.assign(n_seq, {}); + for (auto & c : chain_h) { + c.reserve((size_t) (this->params.n_max + 1) * n_embd); + } + } pending_h.assign(n_seq, std::vector(n_embd, 0.0f)); + i_last.assign(n_seq, -1); i_batch_beg.assign(n_seq, -1); i_batch_end.assign(n_seq, -1); verify_h.assign(n_seq, {}); verify_h_rows.assign(n_seq, 0); - - last_n_drafted.assign(n_seq, 0); } ~common_speculative_impl_draft_mtp() override { @@ -1050,9 +1111,34 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { set_h(i_batch_beg[seq_id], pending_h[seq_id].data()); } - const int32_t rc = llama_decode(ctx_dft, batch); - if (rc != 0) { - LOG_ERR("%s: llama_decode(ctx_dft) failed rc=%d (pos=%d)\n", __func__, (int) rc, (int) batch_in.pos[0]); + auto * mem_dft = llama_get_memory(ctx_dft); + + bool ok = true; + for (int head = 0; head < n_mtp_layers; ++head) { + if (chain_heads) { + // ref: https://github.com/ggml-org/llama.cpp/pull/24340/changes#r3413498544 + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (i_batch_beg[seq_id] < 0) { + continue; + } + llama_memory_seq_rm(mem_dft, seq_id, batch_in.pos[i_batch_beg[seq_id]], -1); + } + llama_set_nextn_layer_offset(ctx_dft, head); + } + + const int32_t rc = llama_decode(ctx_dft, batch); + if (rc != 0) { + LOG_ERR("%s: llama_decode(ctx_dft) head=%d failed rc=%d (pos=%d)\n", + __func__, head, (int) rc, (int) batch_in.pos[0]); + ok = false; + break; + } + } + + if (chain_heads) { + llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes + } + if (!ok) { return false; } } @@ -1087,7 +1173,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { int n_drafting = 0; std::vector drafting(n_seq); - const float * h_row = nullptr; const size_t row_bytes = (size_t) n_embd * sizeof(float); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { @@ -1102,22 +1187,43 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { common_sampler_reset(smpls[seq_id].get()); common_batch_add(batch, dp.id_last, dp.n_past, { seq_id }, true); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, pending_h[seq_id].data(), row_bytes); - h_row = pending_h[seq_id].data(); - std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); - } + i_last[seq_id] = batch.n_tokens - 1; - int ret = llama_decode(ctx_dft, batch); - if (ret != 0) { - LOG_WRN("%s: llama_decode returned %d\n", __func__, ret); - return; + if (chain_heads) { + chain_h[seq_id].assign(pending_h[seq_id].begin(), pending_h[seq_id].end()); + } } int i = 0; while (n_drafting > 0) { - int i_batch = 0; + // each step decodes under a different head, i.e. a different decoder layer, and + // KV is per layer. process() filled this layer's KV only for positions < n_past + // (prompt + accepted prefix) โ€” nothing in the draft region yet. so reset the + // draft region (the seq_rm lower bound is n_past, leaving the prompt KV intact) + // and select head i so it rebuilds its own layer's KV there; decoding just the + // latest token would leave its attention reading cells only another head wrote. + if (chain_heads) { + auto * mem_dft = llama_get_memory(ctx_dft); + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { + if (drafting[seq_id]) { + llama_memory_seq_rm(mem_dft, seq_id, dparams[seq_id].n_past, -1); + } + } + llama_set_nextn_layer_offset(ctx_dft, i); + } + int ret = llama_decode(ctx_dft, batch); + if (ret != 0) { + LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); + break; + } + + // rebuild the batch for the next step: the growing-KV paths re-add only the + // new token (the KV already holds the prefix), while chained heads re-add the + // whole prefix at the next head. dropped sequences are simply not re-added. common_batch_clear(batch); for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { @@ -1127,9 +1233,8 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { auto * smpl = smpls[seq_id].get(); - common_sampler_sample(smpl, ctx_dft, i_batch, true); - h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_batch); - ++i_batch; + common_sampler_sample(smpl, ctx_dft, i_last[seq_id], true); + const float * h_row = llama_get_embeddings_nextn_ith(ctx_dft, i_last[seq_id]); const auto * cur_p = common_sampler_get_candidates(smpl, true); @@ -1163,30 +1268,41 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { continue; } - if (is_mem_shared) { + if (chain_heads) { + // ref: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448031546 + chain_h[seq_id].insert(chain_h[seq_id].end(), h_row, h_row + n_embd); + + const int n_rows = (int) result.size() + 1; // id_last + tokens drafted so far + for (int t = 0; t < n_rows; ++t) { + const llama_token tok = (t == 0) ? dp.id_last : result[t - 1]; + common_batch_add(batch, tok, dp.n_past + t, { seq_id }, t == n_rows - 1); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, + chain_h[seq_id].data() + (size_t) t * n_embd, row_bytes); + } + } else if (is_mem_shared) { // note: with shared memory (e.g. Gemma4 assistants) we use the same position for all draft tokens // ref: https://github.com/huggingface/transformers/blob/effde20942e3f82a1b97449f60b3a48c5ff96145/docs/source/en/model_doc/gemma4_assistant.md?plain=1#L36-L37 common_batch_add(batch, id, dp.n_past, { seq_id }, true); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes); } else { common_batch_add(batch, id, dp.n_past + i + 1, { seq_id }, true); + std::memcpy(batch.embd + (size_t) (batch.n_tokens - 1) * n_embd, h_row, row_bytes); } - std::memcpy(batch.embd + n_embd*(batch.n_tokens - 1), h_row, row_bytes); + + i_last[seq_id] = batch.n_tokens - 1; } if (batch.n_tokens == 0) { break; } - // evaluate the drafted tokens on the draft model - ret = llama_decode(ctx_dft, batch); - if (ret != 0) { - LOG_WRN("%s: llama_decode[%d] returned %d\n", __func__, i, ret); - break; - } - ++i; } + if (chain_heads) { + llama_set_nextn_layer_offset(ctx_dft, 0); // restore default for non-draft decodes + } + for (llama_seq_id seq_id = 0; seq_id < (llama_seq_id) n_seq; ++seq_id) { auto & dp = dparams[seq_id]; if (!dp.drafting) { @@ -1196,8 +1312,6 @@ struct common_speculative_impl_draft_mtp : public common_speculative_impl { if (dp.result->size() < (size_t) params.n_min) { dp.result->clear(); } - - last_n_drafted[seq_id] = (uint16_t) dp.result->size(); } } @@ -1810,7 +1924,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, bool has_draft_simple = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_SIMPLE)); bool has_draft_eagle3 = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3)) && params.draft.ctx_dft != nullptr; - bool has_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr; + bool has_draft_mtp = (enabled_configs & (1u << COMMON_SPECULATIVE_TYPE_DRAFT_MTP)) && params.draft.ctx_dft != nullptr; @@ -1848,7 +1962,7 @@ common_speculative * common_speculative_init(common_params_speculative & params, if (has_draft_eagle3) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_EAGLE3, params)); } - if (has_mtp) { + if (has_draft_mtp) { configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT_MTP, params)); } } @@ -2118,6 +2232,31 @@ void common_speculative_accept(common_speculative * spec, llama_seq_id seq_id, u } } +// TODO: support the case of more than one speculative implementations having a state +bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector & data) { + if (spec == nullptr) { + return false; + } + + for (auto & impl : spec->impls) { + if (impl->get_state(seq_id, data)) { + return true; + } + } + + return false; +} + +void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector & data) { + if (spec == nullptr) { + return; + } + + for (auto & impl : spec->impls) { + impl->set_state(seq_id, data); + } +} + void common_speculative_print_stats(const common_speculative * spec) { if (spec == nullptr) { return; diff --git a/common/speculative.h b/common/speculative.h index bf76ad709e..c58fac3cc6 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -68,6 +68,10 @@ void common_speculative_draft(common_speculative * spec); // informs the speculative context that n_accepted tokens were accepted by the target model void common_speculative_accept(common_speculative * spec, llama_seq_id, uint16_t n_accepted); +// (optional) get/set internal state +bool common_speculative_get_state(common_speculative * spec, llama_seq_id seq_id, std::vector & data); +void common_speculative_set_state(common_speculative * spec, llama_seq_id seq_id, const std::vector & data); + // print statistics about the speculative decoding void common_speculative_print_stats(const common_speculative * spec); diff --git a/conversion/__init__.py b/conversion/__init__.py index 00192cf33a..5aad203e53 100644 --- a/conversion/__init__.py +++ b/conversion/__init__.py @@ -46,6 +46,7 @@ TEXT_MODEL_MAP: dict[str, str] = { "DbrxForCausalLM": "dbrx", "DeciLMForCausalLM": "deci", "DeepseekForCausalLM": "deepseek", + "DeepseekOCRForCausalLM": "deepseek", "DeepseekV2ForCausalLM": "deepseek", "DeepseekV3ForCausalLM": "deepseek", "DeepseekV32ForCausalLM": "deepseek", @@ -96,6 +97,7 @@ TEXT_MODEL_MAP: dict[str, str] = { "GraniteMoeHybridForCausalLM": "granite", "GraniteMoeSharedForCausalLM": "granite", "GraniteSpeechForConditionalGeneration": "granite", + "GraniteSpeechPlusForConditionalGeneration": "granite", "Grok1ForCausalLM": "grok", "GrokForCausalLM": "grok", "GroveMoeForCausalLM": "grovemoe", @@ -123,6 +125,7 @@ TEXT_MODEL_MAP: dict[str, str] = { "LLaDAModelLM": "llada", "LLaMAForCausalLM": "llama", "Lfm25AudioTokenizer": "lfm2", + "Lfm2BidirectionalModel": "lfm2", "Lfm2ForCausalLM": "lfm2", "Lfm2Model": "lfm2", "Lfm2MoeForCausalLM": "lfm2", @@ -133,6 +136,7 @@ TEXT_MODEL_MAP: dict[str, str] = { "LlamaModel": "llama", "Eagle3DraftModel": "llama", "Eagle3Speculator": "llama", + "Eagle3LlamaForCausalLM": "llama", "LlamaForCausalLMEagle3": "llama", "LlavaForConditionalGeneration": "llama", "LlavaStableLMEpochForCausalLM": "stablelm", @@ -231,6 +235,7 @@ TEXT_MODEL_MAP: dict[str, str] = { "UMT5ForConditionalGeneration": "t5", "UMT5Model": "t5", "UltravoxModel": "ultravox", + "UnlimitedOCRForCausalLM": "deepseek", "VLlama3ForCausalLM": "llama", "VoxtralForConditionalGeneration": "llama", "WavTokenizerDec": "wavtokenizer", @@ -261,6 +266,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = { "GlmasrModel": "ultravox", "Granite4VisionForConditionalGeneration": "granite", "GraniteSpeechForConditionalGeneration": "granite", + "GraniteSpeechPlusForConditionalGeneration": "granite", "HunYuanVLForConditionalGeneration": "hunyuan", "Idefics3ForConditionalGeneration": "smolvlm", "InternVisionModel": "internvl", @@ -296,6 +302,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = { "StepVLForConditionalGeneration": "step3", "Step3p7ForConditionalGeneration": "step3", "UltravoxModel": "ultravox", + "UnlimitedOCRForCausalLM": "deepseek", "VoxtralForConditionalGeneration": "ultravox", "YoutuVLForConditionalGeneration": "youtuvl", } diff --git a/conversion/bailingmoe.py b/conversion/bailingmoe.py index 319ff6dabe..2c6425cb64 100644 --- a/conversion/bailingmoe.py +++ b/conversion/bailingmoe.py @@ -126,7 +126,7 @@ class BailingMoeV2Model(TextModel): if (rope_dim := hparams.get("head_dim")) is None: rope_dim = hparams["hidden_size"] // hparams["num_attention_heads"] - self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5))) self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) diff --git a/conversion/base.py b/conversion/base.py index c872bcbb3c..08fd3747c4 100644 --- a/conversion/base.py +++ b/conversion/base.py @@ -1119,8 +1119,10 @@ class TextModel(ModelBase): rope_theta = self.find_hparam(["global_rope_theta", "rope_global_theta", "rope_theta_global", "rope_theta", "rotary_emb_base"], optional=True) local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "rope_theta_local", "swa_rope_theta", "rope_local_base_freq"], optional=True) + partial_rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"], optional=True) + original_max_position_embeddings = self.find_hparam(["original_max_position_embeddings"], optional=True) - # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters + # Ensure global params are mirrored in rope_parameters if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters: if local_rope_theta is not None: self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta} @@ -1128,6 +1130,10 @@ class TextModel(ModelBase): self.rope_parameters["rope_theta"] = rope_theta if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None: self.rope_parameters["rope_type"] = rope_type + if "partial_rotary_factor" not in self.rope_parameters and partial_rotary_factor is not None: + self.rope_parameters["partial_rotary_factor"] = partial_rotary_factor + if "original_max_position_embeddings" not in self.rope_parameters and original_max_position_embeddings is not None: + self.rope_parameters["original_max_position_embeddings"] = original_max_position_embeddings @classmethod def __init_subclass__(cls): diff --git a/conversion/chatglm.py b/conversion/chatglm.py index 7e323b8900..801913075d 100644 --- a/conversion/chatglm.py +++ b/conversion/chatglm.py @@ -148,7 +148,7 @@ class ChatGLMModel(TextModel): rope_dim = self.hparams["attention_dim"] else: rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5))) + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5))) self.gguf_writer.add_add_bos_token(False) rope_freq = 10000 if "rope_ratio" in self.hparams: diff --git a/conversion/deci.py b/conversion/deci.py index 46d8568c5a..be446eefa6 100644 --- a/conversion/deci.py +++ b/conversion/deci.py @@ -161,7 +161,7 @@ class DeciModel(TextModel): factor = rope_params.get("factor", 8.0) low_freq_factor = rope_params.get("low_freq_factor", 1.0) high_freq_factor = rope_params.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = rope_params.get("original_max_position_embeddings", 8192) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor diff --git a/conversion/deepseek.py b/conversion/deepseek.py index 72520cc9f6..4c93fb66df 100644 --- a/conversion/deepseek.py +++ b/conversion/deepseek.py @@ -14,7 +14,7 @@ from .base import MmprojModel, ModelBase, TextModel, gguf, logger from .qwen import QwenModel -@ModelBase.register("DeepseekOCRForCausalLM") +@ModelBase.register("DeepseekOCRForCausalLM", "UnlimitedOCRForCausalLM") class DeepseekOCRVisionModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -205,6 +205,8 @@ class DeepseekModel(TextModel): @ModelBase.register( "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", + "DeepseekOCRForCausalLM", + "UnlimitedOCRForCausalLM", "KimiVLForConditionalGeneration", "KimiK25ForConditionalGeneration", "YoutuForCausalLM", @@ -224,7 +226,7 @@ class DeepseekV2Model(TextModel): self.origin_hf_arch = hparams.get('architectures', [None])[0] # special handling for Deepseek OCR - if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM"): + if self.origin_hf_arch in ("DeepseekOCRForCausalLM", "DeepseekOCR2ForCausalLM", "UnlimitedOCRForCausalLM"): self.model_arch = gguf.MODEL_ARCH.DEEPSEEK2OCR self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch] self.gguf_writer.add_architecture() @@ -350,6 +352,12 @@ class DeepseekV2Model(TextModel): self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) + # Unlimited-OCR sliding window; written for metadata, the decoder ignores it (full MHA) + if is_ocr: + sliding_window = hparams.get("sliding_window_size") or hparams.get("sliding_window") + if sliding_window: + self.gguf_writer.add_sliding_window(sliding_window) + if (rope_mscale_all := self.rope_parameters.get("mscale_all_dim")) is not None: # [TAG_DEEPSEEK2_YARN_LOG_MUL_FIX] # note: for legacy reasons, this is not consistent with the other usages of self.gguf_writer.add_rope_scaling_yarn_log_mul diff --git a/conversion/exaone.py b/conversion/exaone.py index b21f027842..bc4fb3f1b1 100644 --- a/conversion/exaone.py +++ b/conversion/exaone.py @@ -24,7 +24,7 @@ class ExaoneModel(TextModel): assert (hparams["activation_function"] == "silu") - rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"], optional=True) + rotary_factor = self.rope_parameters.get("partial_rotary_factor") rotary_factor = rotary_factor if rotary_factor is not None else 1.0 self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) @@ -39,7 +39,7 @@ class ExaoneModel(TextModel): factor = rope_params.get("factor", 8.0) low_freq_factor = rope_params.get("low_freq_factor", 1.0) high_freq_factor = rope_params.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = rope_params.get("original_max_position_embeddings", 8192) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor @@ -104,7 +104,7 @@ class Exaone4Model(TextModel): factor = rope_params.get("factor", 16.0) low_freq_factor = rope_params.get("low_freq_factor", 1.0) high_freq_factor = rope_params.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = rope_params.get("original_max_position_embeddings", 8192) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor diff --git a/conversion/gemma.py b/conversion/gemma.py index 5b4ca5c583..c552df732b 100644 --- a/conversion/gemma.py +++ b/conversion/gemma.py @@ -693,7 +693,7 @@ class Gemma4Model(Gemma3Model): self.gguf_writer.add_head_count_kv(value_arr) # handle n_rot differently for global vs swa layers - partial_rotary_factor_swa = self.hparams.get("partial_rotary_factor", 1.0) + partial_rotary_factor_swa = self.rope_parameters.get("partial_rotary_factor", 1.0) n_rot_full = int(head_dim_full) # "proportional" is used, see generate_extra_tensors n_rot_swa = int(head_dim_swa * partial_rotary_factor_swa) self.gguf_writer.add_rope_dimension_count(n_rot_full) diff --git a/conversion/glm.py b/conversion/glm.py index 641937720d..895cefc22b 100644 --- a/conversion/glm.py +++ b/conversion/glm.py @@ -124,7 +124,7 @@ class Glm4MoeModel(TextModel): self.hparams["hidden_size"] // self.hparams["num_attention_heads"] ) self.gguf_writer.add_rope_dimension_count( - int(rope_dim * self.hparams.get("partial_rotary_factor", 0.5)) + int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.5)) ) # MoE parameters - Use only routed expert count (shared experts handled separately) @@ -226,7 +226,7 @@ class GlmMoeDsaModel(DeepseekV2Model): super().set_gguf_parameters() rope_dim = self.hparams["qk_rope_head_dim"] - partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0) + partial_rotary_factor = self.rope_parameters.get("partial_rotary_factor", 1.0) self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor)) # NextN/MTP prediction layers diff --git a/conversion/granite.py b/conversion/granite.py index 53441fe570..8367ed225d 100644 --- a/conversion/granite.py +++ b/conversion/granite.py @@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("GraniteSpeechPlusForConditionalGeneration") +class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel): + """Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation""" + has_vision_encoder = False + has_audio_encoder = True + + def set_gguf_parameters(self): + assert self.hparams_audio is not None + super().set_gguf_parameters() + + # Add feature_layer if present in encoder config + if feature_layers := self.hparams_audio.get("cat_hidden_layers"): + self.gguf_writer.add_audio_feature_layers(feature_layers) + logger.info(f"gguf: audio feature_layers = {feature_layers}") + + # Validate projector dimension matches concatenated encoder output + hidden_dim = self.hparams_audio["hidden_dim"] + expected_dim = hidden_dim * (len(feature_layers) + 1) + projector_dim = self.global_config["projector_config"]["encoder_hidden_size"] + + if projector_dim != expected_dim: + raise ValueError( + f"Projector encoder_hidden_size ({projector_dim}) does not match " + f"expected concatenated dimension ({expected_dim}). " + f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}" + ) + + @ModelBase.register("Granite4VisionForConditionalGeneration") class Granite4VisionMmprojModel(MmprojModel): has_vision_encoder = True diff --git a/conversion/lfm2.py b/conversion/lfm2.py index f28fccf10f..70ce45658b 100644 --- a/conversion/lfm2.py +++ b/conversion/lfm2.py @@ -64,11 +64,17 @@ class LFM2Model(TextModel): yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Lfm2Model") +@ModelBase.register("Lfm2Model", "Lfm2BidirectionalModel") class LFM2ColBertModel(LFM2Model): model_arch = gguf.MODEL_ARCH.LFM2 dense_tensor_name = "dense_2" + def set_gguf_parameters(self): + super().set_gguf_parameters() + if self.hf_arch == "Lfm2BidirectionalModel": + self.gguf_writer.add_causal_attention(False) + self._try_set_pooling_type() + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if not name.startswith(self.dense_tensor_name): name = "model." + name @@ -76,10 +82,11 @@ class LFM2ColBertModel(LFM2Model): yield from super().modify_tensors(data_torch, name, bid) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: - # dense tensor is stored in a separate safetensors file + # optional dense tensor is stored in a separate safetensors file from safetensors.torch import load_file tensors_file = self.dir_model / "1_Dense" / "model.safetensors" - assert tensors_file.is_file() + if not tensors_file.is_file(): + return tensor = load_file(tensors_file)["linear.weight"] self.gguf_writer.add_embedding_length_out(tensor.shape[0]) yield f"{self.dense_tensor_name}.weight", tensor.clone() diff --git a/conversion/llama.py b/conversion/llama.py index b87bf92d46..b43cc994aa 100644 --- a/conversion/llama.py +++ b/conversion/llama.py @@ -23,6 +23,7 @@ from .base import ModelBase, TextModel, gguf, logger "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", "LlamaForCausalLMEagle3", + "Eagle3LlamaForCausalLM", "Eagle3Speculator", "Eagle3DraftModel", "IQuestCoderForCausalLM", @@ -289,7 +290,7 @@ class LlamaModel(TextModel): factor = rope_params.get("factor", 8.0) low_freq_factor = rope_params.get("low_freq_factor", 1.0) high_freq_factor = rope_params.get("high_freq_factor", 4.0) - old_context_len = self.hparams.get("original_max_position_embeddings", 8192) + old_context_len = rope_params.get("original_max_position_embeddings", 8192) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor diff --git a/conversion/mimo.py b/conversion/mimo.py index d4067aab4b..11ec286794 100644 --- a/conversion/mimo.py +++ b/conversion/mimo.py @@ -154,7 +154,7 @@ class MimoV2Model(TextModel): self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"]) self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) - rope_dim = int(self.hparams["head_dim"] * self.hparams["partial_rotary_factor"]) + rope_dim = int(self.hparams["head_dim"] * self.rope_parameters["partial_rotary_factor"]) self.gguf_writer.add_rope_dimension_count(rope_dim) self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("layernorm_epsilon", 1e-5)) diff --git a/conversion/minicpm.py b/conversion/minicpm.py index e9a4c4a74d..e31b26a008 100644 --- a/conversion/minicpm.py +++ b/conversion/minicpm.py @@ -32,11 +32,9 @@ class MiniCPMModel(TextModel): def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: rope_dims = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - rope_scaling = self.find_hparam(['rope_scaling'], True) - if rope_scaling is not None: - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) - + long_factors = self.rope_parameters.get('long_factor') + short_factors = self.rope_parameters.get('short_factor') + if long_factors or short_factors: if long_factors is None or short_factors is None: raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') @@ -85,13 +83,11 @@ class MiniCPM3Model(TextModel): self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: - rope_scaling = self.find_hparam(['rope_scaling'], True) - if rope_scaling is not None: + long_factors = self.rope_parameters.get('long_factor') + short_factors = self.rope_parameters.get('short_factor') + if long_factors or short_factors: rope_dims = self.hparams["qk_rope_head_dim"] - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) - if long_factors is None or short_factors is None: raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') diff --git a/conversion/nemotron.py b/conversion/nemotron.py index dfeeb97858..e44688a788 100644 --- a/conversion/nemotron.py +++ b/conversion/nemotron.py @@ -125,17 +125,18 @@ class NemotronModel(TextModel): self.gguf_writer.add_layer_norm_eps(f_norm_eps) # * Partial RoPE - rot_pct = self.find_hparam(["partial_rotary_factor", "rope_pct", "rope_percent"]) + rot_pct = self.rope_parameters["partial_rotary_factor"] n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) self.gguf_writer.add_rope_dimension_count(int(rot_pct * n_embd) // n_head) # * RopeScaling for Nemotron - if "rope_scaling" not in self.hparams or self.hparams["rope_scaling"] is None: + factor = self.hparams.get("factor") or self.rope_parameters.get("factor") + if factor is None: self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) else: self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.LINEAR) - self.gguf_writer.add_rope_scaling_factor(self.hparams["factor"]) + self.gguf_writer.add_rope_scaling_factor(factor) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # * Adding +1 to LayerNorm's weights here to implement layernorm1p w/o changing anything on the GGML engine side diff --git a/conversion/phi.py b/conversion/phi.py index 5e0d72847a..df4bfe809a 100644 --- a/conversion/phi.py +++ b/conversion/phi.py @@ -18,7 +18,7 @@ class Phi2Model(TextModel): model_arch = gguf.MODEL_ARCH.PHI2 def set_gguf_parameters(self): - rot_pct = self.find_hparam(["partial_rotary_factor"]) + rot_pct = self.rope_parameters["partial_rotary_factor"] n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) @@ -149,8 +149,8 @@ class Phi3MiniModel(TextModel): n_head_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"]) rms_eps = self.find_hparam(["rms_norm_eps"]) max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) - orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) - rot_pct = self.hparams.get("partial_rotary_factor", 1.0) + orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"] + rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0) rope_dims = int(rot_pct * n_embd) // n_head self.gguf_writer.add_context_length(max_pos_embds) @@ -174,18 +174,19 @@ class Phi3MiniModel(TextModel): n_embd = self.find_hparam(["hidden_size", "n_embd"]) n_head = self.find_hparam(["num_attention_heads", "n_head"]) max_pos_embds = self.find_hparam(["n_positions", "max_position_embeddings"]) - orig_max_pos_embds = self.find_hparam(["original_max_position_embeddings"]) - rot_pct = self.hparams.get("partial_rotary_factor", 1.0) + orig_max_pos_embds = self.rope_parameters["original_max_position_embeddings"] + rot_pct = self.rope_parameters.get("partial_rotary_factor", 1.0) rope_dims = int(rot_pct * n_embd) // n_head # write rope scaling for long context (128k) model - rope_scaling = self.find_hparam(['rope_scaling'], True) - if rope_scaling is None: + long_factors = self.rope_parameters.get('long_factor') + short_factors = self.rope_parameters.get('short_factor') + if not long_factors: return scale = max_pos_embds / orig_max_pos_embds - rope_scaling_type = rope_scaling.get('rope_type', rope_scaling.get('type', '')).lower() + rope_scaling_type = self.rope_parameters.get('rope_type', '').lower() if len(rope_scaling_type) == 0: raise KeyError('Missing the required key rope_scaling.type') @@ -198,9 +199,6 @@ class Phi3MiniModel(TextModel): self.gguf_writer.add_rope_scaling_attn_factors(attn_factor) - long_factors = rope_scaling.get('long_factor', None) - short_factors = rope_scaling.get('short_factor', None) - if long_factors is None or short_factors is None: raise KeyError('Missing the required key rope_scaling.long_factor or rope_scaling_short_factor') diff --git a/conversion/qwen.py b/conversion/qwen.py index 7eb135c832..6b85eb9aaf 100644 --- a/conversion/qwen.py +++ b/conversion/qwen.py @@ -280,7 +280,7 @@ class Qwen3NextModel(Qwen2MoeModel): self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4)) if (rope_dim := self.hparams.get("head_dim")) is None: rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] - self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25))) + self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.rope_parameters.get("partial_rotary_factor", 0.25))) @classmethod def filter_tensors(cls, item: tuple[str, Callable[[], Tensor]]) -> tuple[str, Callable[[], Tensor]] | None: diff --git a/conversion/stablelm.py b/conversion/stablelm.py index ba5e9aa6ca..6e16378a03 100644 --- a/conversion/stablelm.py +++ b/conversion/stablelm.py @@ -28,7 +28,7 @@ class StableLMModel(TextModel): self.gguf_writer.add_embedding_length(hparams["hidden_size"]) self.gguf_writer.add_block_count(self.block_count) self.gguf_writer.add_feed_forward_length(hparams["intermediate_size"]) - rotary_factor = self.find_hparam(["partial_rotary_factor", "rope_pct"]) + rotary_factor = self.rope_parameters["partial_rotary_factor"] self.gguf_writer.add_rope_dimension_count(int(rotary_factor * (hparams["hidden_size"] // hparams["num_attention_heads"]))) self.gguf_writer.add_head_count(hparams["num_attention_heads"]) self.gguf_writer.add_head_count_kv(hparams["num_key_value_heads"]) diff --git a/conversion/step3.py b/conversion/step3.py index 8c45b61c95..49bb5244a6 100644 --- a/conversion/step3.py +++ b/conversion/step3.py @@ -314,7 +314,7 @@ class Step35Model(TextModel): factor = float(rope_params.get("factor", 8.0)) low_freq_factor = float(rope_params.get("low_freq_factor", 1.0)) high_freq_factor = float(rope_params.get("high_freq_factor", 4.0)) - old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192))) + old_context_len = int(rope_params.get("original_max_position_embeddings", 8192)) low_freq_wavelen = old_context_len / low_freq_factor high_freq_wavelen = old_context_len / high_freq_factor diff --git a/docs/android.md b/docs/android.md index 964ce8a1f0..e8d580a9ed 100644 --- a/docs/android.md +++ b/docs/android.md @@ -29,7 +29,7 @@ With Termux, you can install and run `llama.cpp` as if the environment were Linu ``` $ apt update && apt upgrade -y -$ apt install git cmake +$ apt install git cmake libandroid-spawn ``` Then, follow the [build instructions](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md), specifically for CMake. diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index d482d88408..8b0b9a1869 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -413,6 +413,15 @@ In two device selection modes, the default SYCL backend is level_zero, you can c |------------------|----------------------------------------| | Single device | --split-mode none --main-gpu DEVICE_ID | | Multiple devices | --split-mode layer (default) | +| Multiple devices | --split-mode tensor (tensor parallelism) | + +`--split-mode tensor` (tensor parallelism) shards each layer across the selected +GPUs. It requires flash attention, which is auto-enabled when `--flash-attn` is +left at its default `auto`, so `--split-mode tensor` works out of the box. +Passing `--flash-attn off` together with `--split-mode tensor` is rejected at +context creation. The default `f16` KV cache is recommended. Tensor parallelism +is currently optimized for 2 GPUs; other device counts fall back to a generic +all-reduce. Examples: @@ -715,6 +724,15 @@ In two device selection modes, the default SYCL backend is level_zero, you can c |------------------|----------------------------------------| | Single device | --split-mode none --main-gpu DEVICE_ID | | Multiple devices | --split-mode layer (default) | +| Multiple devices | --split-mode tensor (tensor parallelism) | + +`--split-mode tensor` (tensor parallelism) shards each layer across the selected +GPUs. It requires flash attention, which is auto-enabled when `--flash-attn` is +left at its default `auto`, so `--split-mode tensor` works out of the box. +Passing `--flash-attn off` together with `--split-mode tensor` is rejected at +context creation. The default `f16` KV cache is recommended. Tensor parallelism +is currently optimized for 2 GPUs; other device counts fall back to a generic +all-reduce. Examples: diff --git a/docs/backend/snapdragon/CMakeUserPresets.json b/docs/backend/snapdragon/CMakeUserPresets.json index d37100764f..848d735f1c 100644 --- a/docs/backend/snapdragon/CMakeUserPresets.json +++ b/docs/backend/snapdragon/CMakeUserPresets.json @@ -24,7 +24,6 @@ "GGML_LLAMAFILE": "OFF", "GGML_OPENCL": "ON", "GGML_HEXAGON": "ON", - "GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128", "LLAMA_OPENSSL": "OFF" } }, @@ -47,7 +46,6 @@ "GGML_LLAMAFILE": "OFF", "GGML_OPENCL": "ON", "GGML_HEXAGON": "ON", - "GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128", "LLAMA_OPENSSL": "OFF" } }, @@ -73,7 +71,6 @@ "GGML_LLAMAFILE": "OFF", "GGML_OPENCL": "OFF", "GGML_HEXAGON": "ON", - "GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE": "128", "LLAMA_OPENSSL": "OFF" } }, diff --git a/docs/speculative.md b/docs/speculative.md index 43d1818589..8f91256c4a 100644 --- a/docs/speculative.md +++ b/docs/speculative.md @@ -13,6 +13,45 @@ The `llama-server` application supports several implementations of speculative d A much smaller model (called the _draft model_) generates drafts. A draft model is the most used approach in speculative decoding. +### EAGLE-3 (`draft-eagle3`) + +EAGLE-3 uses a small draft model that reads the target model's hidden states to predict the next tokens, so it +reaches higher acceptance than a standalone draft model of the same size. The draft is a one-layer transformer +trained for a specific target model; it shares the target model's tokenizer and, optionally, uses a reduced draft +vocabulary with its own `lm_head`, which is mapped back using a `d2t` table. + +Convert the EAGLE-3 checkpoint with `--target-model-dir` so it inherits the target's tokenizer and the layer +indices to read. Both the SpecForge `LlamaForCausalLMEagle3` and the vLLM/AngelSlim `Eagle3LlamaForCausalLM` +checkpoint formats are supported (for example [`AngelSlim/Qwen3-4B_eagle3`](https://huggingface.co/AngelSlim/Qwen3-4B_eagle3) +for `Qwen/Qwen3-4B`): + +```bash +python convert_hf_to_gguf.py AngelSlim/Qwen3-4B_eagle3 \ + --target-model-dir Qwen/Qwen3-4B --outtype bf16 --outfile Qwen3-4B-eagle3.gguf + +llama-server -m Qwen3-4B.gguf -md Qwen3-4B-eagle3.gguf --spec-type draft-eagle3 +``` + +Supported EAGLE-3 draft models include: + +- [yuhuili/EAGLE3-LLaMA3.1-Instruct-8B](https://huggingface.co/yuhuili/EAGLE3-LLaMA3.1-Instruct-8B) +- [yuhuili/EAGLE3-LLaMA3.3-Instruct-70B](https://huggingface.co/yuhuili/EAGLE3-LLaMA3.3-Instruct-70B) +- [RedHatAI/gemma-4-31B-it-speculator.eagle3](https://huggingface.co/RedHatAI/gemma-4-31B-it-speculator.eagle3) +- [RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3](https://huggingface.co/RedHatAI/gemma-4-26B-A4B-it-speculator.eagle3) +- [Tengyunw/qwen3_8b_eagle3](https://huggingface.co/Tengyunw/qwen3_8b_eagle3) +- [Tengyunw/qwen3_30b_moe_eagle3](https://huggingface.co/Tengyunw/qwen3_30b_moe_eagle3) +- [AngelSlim/Qwen3-1.7B_eagle3](https://huggingface.co/AngelSlim/Qwen3-1.7B_eagle3) +- [AngelSlim/Qwen3-4B_eagle3](https://huggingface.co/AngelSlim/Qwen3-4B_eagle3) +- [AngelSlim/Qwen3-8B_eagle3](https://huggingface.co/AngelSlim/Qwen3-8B_eagle3) +- [AngelSlim/Qwen3-14B_eagle3](https://huggingface.co/AngelSlim/Qwen3-14B_eagle3) +- [AngelSlim/Qwen3-32B_eagle3](https://huggingface.co/AngelSlim/Qwen3-32B_eagle3) +- [AngelSlim/Qwen3-a3B_eagle3](https://huggingface.co/AngelSlim/Qwen3-a3B_eagle3) +- [RedHatAI/gpt-oss-20b-speculator.eagle3](https://huggingface.co/RedHatAI/gpt-oss-20b-speculator.eagle3) +- [lmsys/EAGLE3-gpt-oss-120b-bf16](https://huggingface.co/lmsys/EAGLE3-gpt-oss-120b-bf16) +- [nvidia/gpt-oss-120b-Eagle3-long-context](https://huggingface.co/nvidia/gpt-oss-120b-Eagle3-long-context) + +For the full and up-to-date list of supported models, see #18039. + ### n-gram Cache (`ngram-cache`) An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences. @@ -108,7 +147,7 @@ If a draft model is combined with a draftless decoding the draftless decoding ha ### General Speculative Parameters ``` ---spec-type [none|draft-simple|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod] +--spec-type [none|draft-simple|draft-eagle3|draft-mtp|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod] comma-separated list of types of speculative decoding to use (default: none) (env: LLAMA_ARG_SPEC_TYPE) @@ -247,6 +286,7 @@ Specifies a comma-separated list of speculative decoding types to use. |------|-------------| | `none` | No speculative decoding (default) | | `draft-simple` | Use a simple draft model for speculation | +| `draft-eagle3` | Use an EAGLE-3 draft model that reads the target's hidden states | | `draft-mtp` | Use Multi Token Prediction (MTP) heads from the main model | | `ngram-cache` | Use n-gram cache lookup | | `ngram-simple` | Use simple n-gram pattern matching | diff --git a/examples/json_schema_to_grammar.py b/examples/json_schema_to_grammar.py index 077fcfacac..83abd259da 100755 --- a/examples/json_schema_to_grammar.py +++ b/examples/json_schema_to_grammar.py @@ -198,18 +198,18 @@ class BuiltinRule: SPACE_RULE = '| " " | "\\n"{1,2} [ \\t]{0,20}' PRIMITIVE_RULES = { - 'boolean' : BuiltinRule('("true" | "false") space', []), + 'boolean' : BuiltinRule('("true" | "false")', []), 'decimal-part' : BuiltinRule('[0-9]{1,16}', []), 'integral-part': BuiltinRule('[0] | [1-9] [0-9]{0,15}', []), - 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)? space', ['integral-part', 'decimal-part']), - 'integer' : BuiltinRule('("-"? integral-part) space', ['integral-part']), + 'number' : BuiltinRule('("-"? integral-part) ("." decimal-part)? ([eE] [-+]? integral-part)?', ['integral-part', 'decimal-part']), + 'integer' : BuiltinRule('("-"? integral-part)', ['integral-part']), 'value' : BuiltinRule('object | array | string | number | boolean | null', ['object', 'array', 'string', 'number', 'boolean', 'null']), - 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? "}" space', ['string', 'value']), - 'array' : BuiltinRule('"[" space ( value ("," space value)* )? "]" space', ['value']), - 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\"" space', []), + 'object' : BuiltinRule('"{" space ( string ":" space value ("," space string ":" space value)* )? space "}"', ['string', 'value']), + 'array' : BuiltinRule('"[" space ( value ("," space value)* )? space "]"', ['value']), + 'uuid' : BuiltinRule(r'"\"" [0-9a-fA-F]{8} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{4} "-" [0-9a-fA-F]{12} "\""', []), 'char' : BuiltinRule(r'[^"\\\x7F\x00-\x1F] | [\\] (["\\bfnrt] | "u" [0-9a-fA-F]{4})', []), - 'string' : BuiltinRule(r'"\"" char* "\"" space', ['char']), - 'null' : BuiltinRule('"null" space', []), + 'string' : BuiltinRule(r'"\"" char* "\""', ['char']), + 'null' : BuiltinRule('"null"', []), } # TODO: support "uri", "email" string formats @@ -217,9 +217,9 @@ STRING_FORMAT_RULES = { 'date' : BuiltinRule('[0-9]{4} "-" ( "0" [1-9] | "1" [0-2] ) "-" ( \"0\" [1-9] | [1-2] [0-9] | "3" [0-1] )', []), 'time' : BuiltinRule('([01] [0-9] | "2" [0-3]) ":" [0-5] [0-9] ":" [0-5] [0-9] ( "." [0-9]{3} )? ( "Z" | ( "+" | "-" ) ( [01] [0-9] | "2" [0-3] ) ":" [0-5] [0-9] )', []), 'date-time' : BuiltinRule('date "T" time', ['date', 'time']), - 'date-string' : BuiltinRule('"\\"" date "\\"" space', ['date']), - 'time-string' : BuiltinRule('"\\"" time "\\"" space', ['time']), - 'date-time-string': BuiltinRule('"\\"" date-time "\\"" space', ['date-time']), + 'date-string' : BuiltinRule('"\\"" date "\\""', ['date']), + 'time-string' : BuiltinRule('"\\"" time "\\""', ['time']), + 'date-time-string': BuiltinRule('"\\"" date-time "\\""', ['date-time']), } DOTALL = '[\\U00000000-\\U0010FFFF]' @@ -319,7 +319,7 @@ class SchemaConverter: out.append(f'[^"{"".join(rejects)}] {char_rule}*') visit(trie) - out.append(f' ){"" if trie.is_end_of_string else "?"} ["] space') + out.append(f' ){"" if trie.is_end_of_string else "?"} ["]') return ''.join(out) def _add_rule(self, name, rule): @@ -549,7 +549,7 @@ class SchemaConverter: return self._add_rule( name, to_rule(transform()) if self._raw_pattern \ - else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\" space") + else "\"\\\"\" (" + to_rule(transform()) + ") \"\\\"\"") def _resolve_ref(self, ref): @@ -580,10 +580,10 @@ class SchemaConverter: return self._add_rule(rule_name, self._generate_union_rule(name, [{**schema, 'type': t} for t in schema_type])) elif 'const' in schema: - return self._add_rule(rule_name, self._generate_constant_rule(schema['const']) + ' space') + return self._add_rule(rule_name, self._generate_constant_rule(schema['const'])) elif 'enum' in schema: - rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ') space' + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in schema['enum'])) + ')' return self._add_rule(rule_name, rule) elif schema_type in (None, 'object') and \ @@ -624,7 +624,7 @@ class SchemaConverter: enum_intersection &= s if enum_intersection: - rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space' + rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ')' return self._add_rule(rule_name, rule) return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) @@ -638,12 +638,12 @@ class SchemaConverter: ' "," space '.join( self.visit(item, f'{name}{"-" if name else ""}tuple-{i}') for i, item in enumerate(items)) + - ' "]" space') + ' space "]"') else: item_rule_name = self.visit(items, f'{name}{"-" if name else ""}item') min_items = schema.get("minItems", 0) max_items = schema.get("maxItems") - return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' "]" space') + return self._add_rule(rule_name, '"[" space ' + _build_repetition(item_rule_name, min_items, max_items, separator_rule='"," space') + ' space "]"') elif schema_type in (None, 'string') and 'pattern' in schema: return self._visit_pattern(schema['pattern'], rule_name) @@ -663,7 +663,7 @@ class SchemaConverter: min_len = schema.get('minLength', 0) max_len = schema.get('maxLength') - return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\"" space') + return self._add_rule(rule_name, r'"\"" ' + _build_repetition(char_rule, min_len, max_len) + r' "\""') elif schema_type in (None, 'integer') and \ ('minimum' in schema or 'exclusiveMinimum' in schema or 'maximum' in schema or 'exclusiveMaximum' in schema): @@ -680,7 +680,7 @@ class SchemaConverter: out = ["("] _generate_min_max_int(min_value, max_value, out) - out.append(") space") + out.append(")") return self._add_rule(rule_name, ''.join(out)) elif (schema_type == 'object') or (len(schema) == 0): @@ -765,7 +765,7 @@ class SchemaConverter: rule += ' )' rule += ' )?' - rule += ' "}" space' + rule += ' space "}"' return rule diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 0507e0c5aa..a0cd4e7158 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -5,7 +5,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 15) -set(GGML_VERSION_PATCH 1) +set(GGML_VERSION_PATCH 2) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/") @@ -266,7 +266,6 @@ set (GGML_OPENCL_TARGET_VERSION "300" CACHE STRING "ggml: OpenCL API version to target") option(GGML_HEXAGON "ggml: enable Hexagon backend" OFF) -set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml: quantize group size (32, 64, or 128)") # toolchain for vulkan-shaders-gen set (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN "" CACHE FILEPATH "ggml: toolchain file for vulkan-shaders-gen") diff --git a/ggml/include/ggml-sycl.h b/ggml/include/ggml-sycl.h index 5ce349a880..418a7ba978 100644 --- a/ggml/include/ggml-sycl.h +++ b/ggml/include/ggml-sycl.h @@ -27,6 +27,14 @@ GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_buffer_type(int de // split tensor buffer that splits matrices by rows across multiple devices GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_split_buffer_type(const float * tensor_split); +// Tensor parallelism (--split-mode tensor): comm_init/free/allreduce_tensor +// trio queried by the meta-backend via ggml_backend_reg_get_proc_address. +// See typedefs in ggml/include/ggml-backend.h. Mirrors the CUDA backend's +// pattern (ggml_backend_cuda_comm_*). +GGML_BACKEND_API void * ggml_backend_sycl_comm_init(ggml_backend_t * backends, size_t n_backends); +GGML_BACKEND_API void ggml_backend_sycl_comm_free(void * comm_ctx); +GGML_BACKEND_API bool ggml_backend_sycl_comm_allreduce_tensor(void * comm_ctx, struct ggml_tensor ** tensors); + // pinned host buffer for use with the CPU backend for faster copies between CPU and GPU GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_sycl_host_buffer_type(void); diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index d9383a04be..9f3a744b5d 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -2417,15 +2417,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); - parallel_for_ggml(params, n_batch, [&](int begin, int end) { - for (int batch_idx = begin; batch_idx < end; ++batch_idx) { + parallel_for_ggml(params, n_batch * M, [&](int begin, int end) { + for (int idx = begin; idx < end; ++idx) { + int batch_idx = idx / M; + int m = idx % M; int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); const float * A_data = (const float *)((const char *)src1->data + src1_offset); char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A; - - for (int m = 0; m < M; ++m) { - from_float(A_data + m * K, wdata_batch + m * row_size_A, K); - } + from_float(A_data + m * K, wdata_batch + m * row_size_A, K); } }); }); diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index e13828e3be..0b8323e60c 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -2345,7 +2345,7 @@ class tinyBLAS_Q0_PPC { else if (n_aligned % 16 == 0) nc = 16; else nc = 8; } - bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0); + bool can_use_tiled = n_aligned > 0 && (m % mc == 0); if (can_use_tiled) { matmul_tiled(m, n_aligned, mc, nc, kc); if (n > n_aligned) { @@ -3063,13 +3063,14 @@ class tinyBLAS_Q0_PPC { int64_t ii = (job / xtiles) * mc; int64_t jj = (job % xtiles) * nc; for (int64_t kk = 0; kk < k; kk += kc) { + int64_t k_cur = MIN(kc, k - kk); if constexpr(is_Ablock_q4) { - packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + packNormal_q4_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack); } else { - packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + packNormal_q8_fp16(A + ii * lda + kk, lda, mc, k_cur, (uint8_t *)A_pack); } - packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack); - KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack); + packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, k_cur, (uint8_t *)B_pack); + KERNEL_Q0(ii, jj, mc, nc, k_cur, kk, A_pack, B_pack); } } } diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 74611dce7f..6724686b8a 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3688,8 +3688,6 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - const int ith = params->ith; const int nth = params->nth; @@ -3703,25 +3701,49 @@ static void ggml_compute_forward_norm_f32( for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; + char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3; - float sum = 0.0; - ggml_vec_sum_f32(ne00, &sum, x); - float mean = sum/ne00; + if (nb00 == sizeof(float) && nb0 == sizeof(float)) { + const float * xf = (const float *) x; - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - float variance = 0; + float sum = 0.0; + ggml_vec_sum_f32(ne00, &sum, xf); + float mean = sum/ne00; + + float * yf = (float *) y; + float variance = 0; #ifdef GGML_USE_ACCELERATE - mean = -mean; - vDSP_vsadd(x, 1, &mean, y, 1, ne00); - vDSP_measqv(y, 1, &variance, ne00); + mean = -mean; + vDSP_vsadd(xf, 1, &mean, yf, 1, ne00); + vDSP_measqv(yf, 1, &variance, ne00); #else - variance = ggml_vec_cvar_f32(ne00, y, x, mean); + variance = ggml_vec_cvar_f32(ne00, yf, xf, mean); #endif //GGML_USE_ACCELERATE - const float scale = 1.0f/sqrtf(variance + eps); - ggml_vec_scale_f32(ne00, y, scale); + const float scale = 1.0f/sqrtf(variance + eps); + ggml_vec_scale_f32(ne00, yf, scale); + } else { + float sum = 0.0; + for (int64_t i00 = 0; i00 < ne00; i00++) { + sum += *(const float *) (x + i00*nb00); + } + const float mean = sum/ne00; + + float variance = 0.0f; + for (int64_t i00 = 0; i00 < ne00; i00++) { + const float v = *(const float *) (x + i00*nb00) - mean; + *(float *) (y + i00*nb0) = v; + variance += v * v; + } + variance /= ne00; + + const float scale = 1.0f/sqrtf(variance + eps); + for (int64_t i00 = 0; i00 < ne00; i00++) { + *(float *) (y + i00*nb0) *= scale; + } + } } } } @@ -4142,8 +4164,6 @@ static void ggml_compute_forward_l2_norm_f32( GGML_ASSERT(ggml_are_same_shape(src0, dst)); - GGML_ASSERT(src0->nb[0] == sizeof(float)); - const int ith = params->ith; const int nth = params->nth; @@ -4158,20 +4178,27 @@ static void ggml_compute_forward_l2_norm_f32( for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { - const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); + const char * x = (const char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03; ggml_float sum = 0.0; for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)(x[i00] * x[i00]); + const float xi = *(const float *) (x + i00*nb00); + sum += (ggml_float)(xi * xi); } - float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); - - memcpy(y, x, ne00 * sizeof(float)); - const float scale = 1.0f/fmaxf(sqrtf(sum), eps); - ggml_vec_scale_f32(ne00, y, scale); + char * y = (char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3; + + if (nb00 == sizeof(float) && nb0 == sizeof(float)) { + memcpy(y, x, ne00 * sizeof(float)); + ggml_vec_scale_f32(ne00, (float *) y, scale); + } else { + for (int64_t i00 = 0; i00 < ne00; i00++) { + const float xi = *(const float *) (x + i00*nb00); + *(float *) (y + i00*nb0) = xi * scale; + } + } } } } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index c25f42b32b..2e38077bf6 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -34,26 +34,26 @@ template = (uint32_t)ne0 || i1 >= (uint32_t)ne1 || i2 >= (uint32_t)ne2 || i3 >= ne3.z) { + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) { return; } @@ -69,25 +69,32 @@ static __global__ void k_bin_bcast(const src0_t * src0, const uint32_t i12 = fastmodulo(i2, ne12); const uint32_t i13 = fastmodulo(i3, ne13); - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01; + const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11; + const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1; const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; + const uint32_t s0 = blockDim.x * gridDim.x; + ggml_cuda_pdl_sync(); - for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { + for (uint32_t i0 = i0s; i0 < ne0; i0 += s0) { const uint32_t i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; + float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10*s10]); + result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]); } dst_row[i0] = (dst_t) result; + + // protect i0 from overflow + if (ne0 - i0 <= s0) { + break; + } } } @@ -110,19 +117,19 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const uint3 ne12, const uint3 ne13, /*const int s0,*/ - const int s1, - const int s2, - const int s3, - const int s00, - const int s01, - const int s02, - const int s03, - const int s10, - const int s11, - const int s12, - const int s13, + const uint32_t s1, + const uint32_t s2, + const uint32_t s3, + const uint32_t s00, + const uint32_t s01, + const uint32_t s02, + const uint32_t s03, + const uint32_t s10, + const uint32_t s11, + const uint32_t s12, + const uint32_t s13, src1_ptrs... src1s) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; + const uint32_t i = blockDim.x*blockIdx.x + threadIdx.x; const uint32_t i3 = fastdiv(i, prod_012); const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01); @@ -133,25 +140,25 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, return; } - const int i11 = fastmodulo(i1, ne11); - const int i12 = fastmodulo(i2, ne12); - const int i13 = fastmodulo(i3, ne13); + const uint32_t i11 = fastmodulo(i1, ne11); + const uint32_t i12 = fastmodulo(i2, ne12); + const uint32_t i13 = fastmodulo(i3, ne13); - const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1; + const size_t i_src0 = size_t( i3)*s03 + size_t( i2)*s02 + size_t( i1)*s01; + const size_t i_src1 = size_t(i13)*s13 + size_t(i12)*s12 + size_t(i11)*s11; + const size_t i_dst = size_t( i3)*s3 + size_t( i2)*s2 + size_t( i1)*s1; const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; dst_t * dst_row = dst + i_dst; - const int i10 = fastmodulo(i0, ne10); + const uint32_t i10 = fastmodulo(i0, ne10); ggml_cuda_pdl_sync(); - float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; + float result = src0_row ? (float) src0_row[size_t(i0)*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + size_t(i10)*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10*s10]); + result = bin_op(result, (float)src1[i_src1 + size_t(i10)*s10]); } dst_row[i0] = (dst_t) result; @@ -248,6 +255,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * size_t s02 = nb02 / sizeof(src0_t); size_t s03 = nb03 / sizeof(src0_t); + GGML_ASSERT(ne0 <= std::numeric_limits::max()); + GGML_ASSERT(ne1 <= std::numeric_limits::max()); + GGML_ASSERT(ne2 <= std::numeric_limits::max()); + GGML_ASSERT(ne3 <= std::numeric_limits::max()); + + //GGML_ASSERT(s0 <= std::numeric_limits::max()); + GGML_ASSERT(s1 <= std::numeric_limits::max()); + GGML_ASSERT(s2 <= std::numeric_limits::max()); + GGML_ASSERT(s3 <= std::numeric_limits::max()); + + GGML_ASSERT(s00 <= std::numeric_limits::max()); + GGML_ASSERT(s01 <= std::numeric_limits::max()); + GGML_ASSERT(s02 <= std::numeric_limits::max()); + GGML_ASSERT(s03 <= std::numeric_limits::max()); + + GGML_ASSERT(s10 <= std::numeric_limits::max()); + GGML_ASSERT(s11 <= std::numeric_limits::max()); + GGML_ASSERT(s12 <= std::numeric_limits::max()); + GGML_ASSERT(s13 <= std::numeric_limits::max()); + + GGML_ASSERT(cne1[0] <= std::numeric_limits::max()); + GGML_ASSERT(cne1[1] <= std::numeric_limits::max()); + GGML_ASSERT(cne1[2] <= std::numeric_limits::max()); + GGML_ASSERT(cne1[3] <= std::numeric_limits::max()); + GGML_ASSERT(nb0 % sizeof(dst_t) == 0); GGML_ASSERT(nb1 % sizeof(dst_t) == 0); GGML_ASSERT(nb2 % sizeof(dst_t) == 0); @@ -263,6 +295,8 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); + GGML_ASSERT(ne2 * ne3 <= std::numeric_limits::max()); + const int block_size = 128; int64_t hne0 = std::max(ne0 / 2LL, 1LL); @@ -281,7 +315,13 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]); if (block_nums.z > 65535 || block_nums.y > 65535) { - int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + int64_t block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; + + GGML_ASSERT(block_num <= std::numeric_limits::max()); + GGML_ASSERT(block_num * block_size <= std::numeric_limits::max()); + GGML_ASSERT(ne0 * ne1 <= std::numeric_limits::max()); + GGML_ASSERT(ne0 * ne1 * ne2 <= std::numeric_limits::max()); + const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2)); const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1)); const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0); @@ -298,6 +338,10 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } } else { + GGML_ASSERT(int64_t(block_nums.x) * block_dims.x <= std::numeric_limits::max()); + GGML_ASSERT(int64_t(block_nums.y) * block_dims.y <= std::numeric_limits::max()); + GGML_ASSERT(int64_t(block_nums.z) * block_dims.z <= std::numeric_limits::max()); + const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); { const ggml_cuda_kernel_launch_params launch_params = ggml_cuda_kernel_launch_params(block_nums, block_dims, 0, stream); diff --git a/ggml/src/ggml-cuda/col2im-1d.cu b/ggml/src/ggml-cuda/col2im-1d.cu new file mode 100644 index 0000000000..fecd4c6a95 --- /dev/null +++ b/ggml/src/ggml-cuda/col2im-1d.cu @@ -0,0 +1,81 @@ +#include "col2im-1d.cuh" +#include "convert.cuh" + +// col2im_1d: scatter-add GEMM columns to 1D signal (gather approach) +// columns: [K*OC, T_in] -> output: [T_out, OC] +// Supports F32, F16, BF16 data with F32 accumulator. + +template +static __global__ void col2im_1d_kernel( + const T * __restrict__ col, + T * __restrict__ dst, + const int T_in, const uint3 T_out_fd, + const int OC, const int K, const int K_OC, + const int s0, const int p0, const int total) { + + const int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= total) return; + + // dst layout: [T_out, OC], ne[0]=T_out fastest + const uint2 qr = fast_div_modulo((uint32_t)idx, T_out_fd); // qr.x = idx / T_out, qr.y = idx % T_out + const int oc = (int)qr.x; + const int t_out = (int)qr.y; + const int t_abs = t_out + p0; // absolute position in uncropped signal + + // Gather: find all (t_in, k) where t_in*s + k == t_abs, 0 <= k < K + int t_in_min = (t_abs - K + s0) / s0; // ceil((t_abs - K + 1) / s) + if (t_in_min < 0) t_in_min = 0; + int t_in_max = t_abs / s0; + if (t_in_max >= T_in) t_in_max = T_in - 1; + + float sum = 0.0f; + for (int t_in = t_in_min; t_in <= t_in_max; t_in++) { + const int k = t_abs - t_in * s0; + // col layout: [K*OC, T_in], column index = oc * K + k + sum += ggml_cuda_cast(col[(oc * K + k) + t_in * K_OC]); + } + + dst[idx] = ggml_cuda_cast(sum); +} + +void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t OC = ((const int32_t *)(dst->op_params))[1]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[2]; + + const int K_OC = (int) src0->ne[0]; + const int T_in = (int) src0->ne[1]; + const int K = K_OC / OC; + const int T_out = (int) dst->ne[0]; + + const uint3 T_out_fd = init_fastdiv_values((uint32_t)T_out); + + const int total = T_out * OC; + const int block_size = 256; + const int num_blocks = (total + block_size - 1) / block_size; + + switch (src0->type) { + case GGML_TYPE_F32: { + col2im_1d_kernel<<>>( + (const float *)src0->data, (float *)dst->data, + T_in, T_out_fd, OC, K, K_OC, s0, p0, total); + } break; + case GGML_TYPE_F16: { + col2im_1d_kernel<<>>( + (const half *)src0->data, (half *)dst->data, + T_in, T_out_fd, OC, K, K_OC, s0, p0, total); + } break; + case GGML_TYPE_BF16: { + col2im_1d_kernel<<>>( + (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data, + T_in, T_out_fd, OC, K, K_OC, s0, p0, total); + } break; + default: + GGML_ABORT("col2im_1d: unsupported type"); + } +} diff --git a/ggml/src/ggml-cuda/col2im-1d.cuh b/ggml/src/ggml-cuda/col2im-1d.cuh new file mode 100644 index 0000000000..efc3313c4d --- /dev/null +++ b/ggml/src/ggml-cuda/col2im-1d.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_col2im_1d(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 34cdbc81c2..cca70592f8 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -11,6 +11,7 @@ #include "ggml-cuda/argsort.cuh" #include "ggml-cuda/binbcast.cuh" #include "ggml-cuda/clamp.cuh" +#include "ggml-cuda/col2im-1d.cuh" #include "ggml-cuda/concat.cuh" #include "ggml-cuda/conv-transpose-1d.cuh" #include "ggml-cuda/conv2d.cuh" @@ -3051,6 +3052,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CONV_TRANSPOSE_1D: ggml_cuda_op_conv_transpose_1d(ctx,dst); break; + case GGML_OP_COL2IM_1D: + ggml_cuda_op_col2im_1d(ctx, dst); + break; case GGML_OP_POOL_2D: ggml_cuda_op_pool2d(ctx, dst); break; @@ -5316,13 +5320,21 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g } return false; } break; + case GGML_OP_COL2IM_1D: + { + ggml_type src0_type = op->src[0]->type; + return (src0_type == GGML_TYPE_F32 || src0_type == GGML_TYPE_F16 || src0_type == GGML_TYPE_BF16) && + op->type == src0_type && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op); + } break; case GGML_OP_SILU_BACK: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; break; case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_L2_NORM: - return true; + return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_RMS_NORM_BACK: return ggml_is_contiguous(op->src[0]); break; diff --git a/ggml/src/ggml-hexagon/CMakeLists.txt b/ggml/src/ggml-hexagon/CMakeLists.txt index b82bae0c10..c6e49a71d1 100644 --- a/ggml/src/ggml-hexagon/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/CMakeLists.txt @@ -25,7 +25,6 @@ include(ExternalProject) option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) option(GGML_HEXAGON_FA_EXP2_HF "ggml-hexagon: use FP16 exp2 polynomial in FA softmax instead of F32 exp round-trip" OFF) set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate") -set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") add_library(htp_iface OBJECT ${CMAKE_CURRENT_BINARY_DIR}/htp_iface_stub.c) @@ -72,15 +71,12 @@ function(build_htp_skel V) -DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT} -DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT} -DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} - -DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE} -DDSP_VERSION=${V} -DPREBUILT_LIB_DIR="toolv19_${V}") list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so) set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE) endfunction() -build_htp_skel(v68) -build_htp_skel(v69) build_htp_skel(v73) build_htp_skel(v75) build_htp_skel(v79) diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index e612ec392b..3d41c47b65 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -18,6 +18,7 @@ #include #include #include +#include #ifdef _WIN32 # include @@ -41,6 +42,7 @@ #include "ggml-quants.h" #include "htp-opnode.h" #include "htp-ops.h" +#include "htp/matmul-ops.h" #include "htp_iface.h" #include "htp-drv.h" @@ -51,7 +53,7 @@ using u32vec = std::vector; static int opt_arch = 0; // autodetect static size_t opt_ndev = 1; static size_t opt_nhvx = 0; // use all -static int opt_use_hmx = 1; // when set, enable HMX; when 0, use HVX only +static int opt_nhmx = 1; // when set, enable HMX; when 0, use HVX only static size_t opt_vmem = HTP_OP_MAX_VMEM_DEFAULT; // max available va space for buffer mappings static size_t opt_mbuf = 1ul * 1024 * 1024 * 1024; // max buffer size static int opt_etm = 0; @@ -59,6 +61,8 @@ static int opt_verbose = 0; static int opt_profile = 0; // profiling mode (0-disabled, 1-basic, 2-pmu) static int opt_hostbuf = 1; // hostbuf ON by default +static int opt_mm_select = 3; // 3 = HMX -> Tiled -> Flat -> CPU, 2 = Tiled -> Flat -> CPU, 1 = Flat -> CPU + // Default PMU events, if profiling with PMU (mode=2) is enabled // See https://docs.qualcomm.com/doc/80-N2040-60/topic/pmu-events.html // https://docs.qualcomm.com/doc/80-N2040-61/topic/hvx-pmu-events.html @@ -68,22 +72,15 @@ static u32vec opt_pmu_evt { 0x3, 0x111, 0x100, 0x105, 0x240, 0x256, 0x7D, 0x8C } static int opt_opstage = HTP_OPSTAGE_QUEUE | HTP_OPSTAGE_COMPUTE; static int opt_opbatch = 1024; // max number of ops in a batch static int opt_opqueue = 16; // max number of pending batches -static int opt_oppoll = 0; // polling for batch completions static int opt_optrace = 0; // trace buffer size per thread (0 means default) +static int opt_oppoll = 0; // polling for batch completions +static int opt_opfusion = 1; // enable/disable op fusion static std::regex* opt_opfilter = NULL; // regex of ops to not claim #define HEX_VERBOSE(...) \ if (opt_verbose) GGML_LOG_DEBUG(__VA_ARGS__) -static inline uint64_t hex_is_aligned(void * addr, uint32_t align) { - return ((size_t) addr & (align - 1)) == 0; -} - -static inline size_t hex_round_up(size_t n, size_t m) { - return m * ((n + m - 1) / m); -} - static const char * status_to_str(uint32_t status) { switch (status) { case HTP_STATUS_OK: @@ -107,15 +104,15 @@ static void ggml_hexagon_dump_op_exec(const std::string &sess_name, const htp_op if (!opt_verbose) return; htp_opformat fmt(node); - GGML_LOG_DEBUG("ggml-hex: %s execute-op %s: %s : %s : %s : %s : %s : flags 0x%x\n", sess_name.c_str(), - node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, req_flags); + GGML_LOG_DEBUG("ggml-hex: %s execute-op %s|%s|%s|%s|%s|%s|%s|flags 0x%x\n", sess_name.c_str(), + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, fmt.kparams, req_flags); } static void ggml_hexagon_dump_op_supp(const std::string &sess_name, const struct ggml_tensor * op, bool supp) { if (!opt_verbose) return; htp_opformat fmt(htp_opformat(htp_opnode{const_cast(op), {}, HTP_OP_INVALID})); - GGML_LOG_DEBUG("ggml-hex: %s supports-op %s: %s : %s : %s : %s : %s : %s\n", sess_name.c_str(), + GGML_LOG_DEBUG("ggml-hex: %s supports-op %s|%s|%s|%s|%s|%s|%s\n", sess_name.c_str(), ggml_op_desc(op), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.buffs, supp ? "yes" : "no"); } @@ -144,16 +141,52 @@ static void ggml_hexagon_dump_op_prof(const std::string &sess_name, const htp_op char pmu_str[256] = ""; if (opt_profile == 2) { static_assert(HTP_PROF_PMU_NCNT == 8, "current implementation assumes 8 PMU counters"); - sprintf(pmu_str, " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", + snprintf(pmu_str, sizeof(pmu_str), " pmu [%u,%u,%u,%u,%u,%u,%u,%u]", pmu[0], pmu[1], pmu[2], pmu[3], pmu[4], pmu[5], pmu[6], pmu[7]); } htp_opformat fmt(node); float mhz = op_usec > 0 ? (float) op_cycles / op_usec : 0.0f; - GGML_LOG_DEBUG("ggml-hex: %s profile-op %s: %s : %s : %s : %s : usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(), - node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str); + GGML_LOG_DEBUG("ggml-hex: %s profile-op %s|%s|%s|%s|%s|%s|usec %u cycles %u start %u mhz %.1f%s\n", sess_name.c_str(), + node.op_name().c_str(), fmt.names, fmt.dims, fmt.types, fmt.strides, fmt.kparams, op_usec, op_cycles, pd.cycles_start, mhz, pmu_str); } +// ** + +static inline bool ggml_hexagon_is_repack_type(enum ggml_type type) { + return type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || + type == GGML_TYPE_Q8_0 || type == GGML_TYPE_IQ4_NL || + type == GGML_TYPE_MXFP4; +} + +static inline bool ggml_hexagon_is_hmx_weight_type(enum ggml_type type) { + return type == GGML_TYPE_F16 || type == GGML_TYPE_F32 || ggml_hexagon_is_repack_type(type); +} + +struct htp_mm_kernel_params; +struct ggml_hexagon_session; +static void ggml_hexagon_precompute_matmul_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + struct htp_mm_kernel_params * kparams +); + +static void ggml_hexagon_precompute_fused_qkv_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct htp_mm_kernel_params * kparams +); + +static void ggml_hexagon_precompute_fused_ffn_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + struct htp_mm_kernel_params * kparams +); + // ** backend sessions struct ggml_hexagon_opbatch; @@ -180,6 +213,18 @@ struct ggml_hexagon_session { ggml_backend_buffer_type buffer_type = {}; ggml_backend_buffer_type repack_buffer_type = {}; + uint32_t n_threads = 0; + uint32_t n_hvx = 0; + uint32_t n_hmx = 0; + uint64_t vtcm_size = 0; + size_t max_vmem = 0; + size_t max_bufsize = 0; + + struct { + uint64_t uid = 0; + std::vector htp_nodes; + } cached_graph; + ggml_hexagon_session(int dev_id, ggml_backend_dev_t dev) noexcept(false); ~ggml_hexagon_session() noexcept(true); @@ -325,47 +370,7 @@ static enum ggml_status ggml_backend_hexagon_buffer_init_tensor(ggml_backend_buf return GGML_STATUS_SUCCESS; } -// ======== Q4x4x2 ==================== -struct x2_q4 { - int v[2]; -}; - -static x2_q4 unpack_q4(uint8_t v) { - x2_q4 x = { (int) (v & 0x0f) - 8, (int) (v >> 4) - 8 }; - return x; -} - -static void dump_block_q4_0(const block_q4_0 * b, int i) { - HEX_VERBOSE("ggml-hex: repack q4_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_q4(b->qs[0]).v[0], - unpack_q4(b->qs[1]).v[0], unpack_q4(b->qs[2]).v[0], unpack_q4(b->qs[3]).v[0], unpack_q4(b->qs[12]).v[1], - unpack_q4(b->qs[13]).v[1], unpack_q4(b->qs[14]).v[1], unpack_q4(b->qs[15]).v[1], - GGML_FP16_TO_FP32(b->d)); -} - -static void dump_packed_block_q4x4x2(const uint8_t * v, unsigned int i, size_t k) { - static const int qk = QK_Q4_0x4x2; - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded) - - const uint8_t * v_q = v + 0; // quants first - const uint8_t * v_d = v + qrow_size; // then scales - - const uint8_t * q = v_q + i * qblk_size; - const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size); - - HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, - unpack_q4(q[0]).v[0], unpack_q4(q[1]).v[0], unpack_q4(q[2]).v[0], unpack_q4(q[3]).v[0], - unpack_q4(q[60]).v[0], unpack_q4(q[61]).v[0], unpack_q4(q[62]).v[0], unpack_q4(q[63]).v[0], - unpack_q4(q[124]).v[0], unpack_q4(q[125]).v[0], unpack_q4(q[126]).v[0], unpack_q4(q[127]).v[0], - GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3])); - - HEX_VERBOSE("ggml-hex: repack q4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", - i + 1, unpack_q4(q[0]).v[1], unpack_q4(q[1]).v[1], unpack_q4(q[2]).v[1], unpack_q4(q[3]).v[1], - unpack_q4(q[60]).v[1], unpack_q4(q[61]).v[1], unpack_q4(q[62]).v[1], unpack_q4(q[63]).v[1], - unpack_q4(q[124]).v[1], unpack_q4(q[125]).v[1], unpack_q4(q[126]).v[1], unpack_q4(q[127]).v[1], - GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7])); -} +// ** Repack helpers for tiled quantized weights static void unpack_q4_0_quants(uint8_t * qs, const block_q4_0 * x, unsigned int bi) { static const int qk = QK4_0; @@ -388,300 +393,6 @@ static void pack_q4_0_quants(block_q4_0 * x, const uint8_t * qs, unsigned int bi } } -static void repack_row_q4x4x2(uint8_t * y, const block_q4_0 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q4_0(&x[i * 8 + 0], 0); - dump_block_q4_0(&x[i * 8 + 1], 1); - dump_block_q4_0(&x[i * 8 + 2], 2); - dump_block_q4_0(&x[i * 8 + 3], 3); - dump_block_q4_0(&x[i * 8 + 4], 4); - dump_block_q4_0(&x[i * 8 + 5], 5); - dump_block_q4_0(&x[i * 8 + 6], 6); - dump_block_q4_0(&x[i * 8 + 7], 7); - } - } - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - unpack_q4_0_quants(qs, &x[i * 8 + 0], 0); - unpack_q4_0_quants(qs, &x[i * 8 + 1], 1); - unpack_q4_0_quants(qs, &x[i * 8 + 2], 2); - unpack_q4_0_quants(qs, &x[i * 8 + 3], 3); - unpack_q4_0_quants(qs, &x[i * 8 + 4], 4); - unpack_q4_0_quants(qs, &x[i * 8 + 5], 5); - unpack_q4_0_quants(qs, &x[i * 8 + 6], 6); - unpack_q4_0_quants(qs, &x[i * 8 + 7], 7); - - bool partial = (nloe && i == nb-1); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; - } - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Repack the scales - ggml_half * d = (ggml_half *) (y_d + i * dblk_size); - d[0] = x[i * 8 + 0].d; - d[1] = x[i * 8 + 1].d; - d[2] = x[i * 8 + 2].d; - d[3] = x[i * 8 + 3].d; - d[4] = x[i * 8 + 4].d; - d[5] = x[i * 8 + 5].d; - d[6] = x[i * 8 + 6].d; - d[7] = x[i * 8 + 7].d; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q4x4x2(y, i, k); - } - } -} - -static void unpack_row_q4x4x2(block_q4_0 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q4x4x2(y, i, k); - } - } - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - - bool partial = (nloe && i == nb-1); - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - if (partial) { - qs[j*2+0] = q[j] & 0xf; - qs[j*2+1] = q[j] >> 4; - } else { - qs[j+000] = q[j] & 0xf; - qs[j+128] = q[j] >> 4; - } - } - - pack_q4_0_quants(&x[i * 8 + 0], qs, 0); - pack_q4_0_quants(&x[i * 8 + 1], qs, 1); - pack_q4_0_quants(&x[i * 8 + 2], qs, 2); - pack_q4_0_quants(&x[i * 8 + 3], qs, 3); - pack_q4_0_quants(&x[i * 8 + 4], qs, 4); - pack_q4_0_quants(&x[i * 8 + 5], qs, 5); - pack_q4_0_quants(&x[i * 8 + 6], qs, 6); - pack_q4_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); - x[i * 8 + 0].d = d[0]; - x[i * 8 + 1].d = d[1]; - x[i * 8 + 2].d = d[2]; - x[i * 8 + 3].d = d[3]; - x[i * 8 + 4].d = d[4]; - x[i * 8 + 5].d = d[5]; - x[i * 8 + 6].d = d[6]; - x[i * 8 + 7].d = d[7]; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q4_0(&x[i * 8 + 0], 0); - dump_block_q4_0(&x[i * 8 + 1], 1); - dump_block_q4_0(&x[i * 8 + 2], 2); - dump_block_q4_0(&x[i * 8 + 3], 3); - dump_block_q4_0(&x[i * 8 + 4], 4); - dump_block_q4_0(&x[i * 8 + 5], 5); - dump_block_q4_0(&x[i * 8 + 6], 6); - dump_block_q4_0(&x[i * 8 + 7], 7); - } - } -} - -static void init_row_q4x4x2(block_q4_0 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - // Init the quants such that they unpack into zeros - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - memset(qs, 8, sizeof(qs)); - - for (int i = 0; i < nb; i++) { - pack_q4_0_quants(&x[i * 8 + 0], qs, 0); - pack_q4_0_quants(&x[i * 8 + 1], qs, 1); - pack_q4_0_quants(&x[i * 8 + 2], qs, 2); - pack_q4_0_quants(&x[i * 8 + 3], qs, 3); - pack_q4_0_quants(&x[i * 8 + 4], qs, 4); - pack_q4_0_quants(&x[i * 8 + 5], qs, 5); - pack_q4_0_quants(&x[i * 8 + 6], qs, 6); - pack_q4_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Init the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - x[i * 8 + 0].d = 0; - x[i * 8 + 1].d = 0; - x[i * 8 + 2].d = 0; - x[i * 8 + 3].d = 0; - x[i * 8 + 4].d = 0; - x[i * 8 + 5].d = 0; - x[i * 8 + 6].d = 0; - x[i * 8 + 7].d = 0; - } -} - -// repack q4_0 data into q4x4x2 tensor -static void repack_q4_0_q4x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - // Ensure we don't try to read more data than is available in the source buffer 'data' - // or write more than the tensor can hold. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4_0-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - // re-init the row because we are potentially copying a partial row - init_row_q4x4x2((block_q4_0 *) buf_pd, t->ne[0]); - - // Copy only the remaining bytes from the source. - memcpy(buf_pd, src, n_rem_bytes); - - // Repack the entire buffer - repack_row_q4x4x2((uint8_t *) buf_rp, (const block_q4_0 *) buf_pd, t->ne[0]); - - // Write only the corresponding remaining bytes to the destination tensor. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// repack q4x4x2 tensor into q4_0 data -static void repack_q4x4x2_q4_0(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - // Ensure we don't try to copy more data than the tensor actually contains. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_pd, src, row_size); - unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - // We still need to read and unpack the entire source row because quantization is block-based. - memcpy(buf_pd, src, row_size); - unpack_row_q4x4x2((block_q4_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - - // But we only copy the remaining number of bytes to the destination. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - static void unpack_q4_1_quants(uint8_t * qs, const block_q4_1 * x, unsigned int bi) { static const int qk = QK4_1; @@ -703,603 +414,19 @@ static void pack_q4_1_quants(block_q4_1 * x, const uint8_t * qs, unsigned int bi } } -static void repack_row_q4_1x4x2(uint8_t * y, const block_q4_1 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes - const int qblk_size = qk / 2; // int4 = 128 bytes - const int qrow_size = k / 2; // int4 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_d = y + qrow_size; // then scales/offsets - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - unpack_q4_1_quants(qs, &x[i * 8 + 0], 0); - unpack_q4_1_quants(qs, &x[i * 8 + 1], 1); - unpack_q4_1_quants(qs, &x[i * 8 + 2], 2); - unpack_q4_1_quants(qs, &x[i * 8 + 3], 3); - unpack_q4_1_quants(qs, &x[i * 8 + 4], 4); - unpack_q4_1_quants(qs, &x[i * 8 + 5], 5); - unpack_q4_1_quants(qs, &x[i * 8 + 6], 6); - unpack_q4_1_quants(qs, &x[i * 8 + 7], 7); - - bool partial = (nloe && i == nb-1); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; - } - } - - // Repack the scales and offsets - for (int i = 0; i < nb; i++) { - ggml_half * d_m = (ggml_half *) (y_d + i * dblk_size); - for (int j = 0; j < 8; j++) { - d_m[j * 2 + 0] = x[i * 8 + j].d; - d_m[j * 2 + 1] = x[i * 8 + j].m; - } - } -} - -static void unpack_row_q4_1x4x2(block_q4_1 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int dblk_size = 8 * 4; // 8x (d, m) __fp16 = 32 bytes - const int qblk_size = qk / 2; // int4 = 128 bytes - const int qrow_size = k / 2; // int4 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_d = y + qrow_size; // then scales/offsets - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; - bool partial = (nloe && i == nb-1); - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - if (partial) { - qs[j*2+0] = q[j] & 0x0F; - qs[j*2+1] = q[j] >> 4; - } else { - qs[j+000] = q[j] & 0x0F; - qs[j+128] = q[j] >> 4; - } - } - - pack_q4_1_quants(&x[i * 8 + 0], qs, 0); - pack_q4_1_quants(&x[i * 8 + 1], qs, 1); - pack_q4_1_quants(&x[i * 8 + 2], qs, 2); - pack_q4_1_quants(&x[i * 8 + 3], qs, 3); - pack_q4_1_quants(&x[i * 8 + 4], qs, 4); - pack_q4_1_quants(&x[i * 8 + 5], qs, 5); - pack_q4_1_quants(&x[i * 8 + 6], qs, 6); - pack_q4_1_quants(&x[i * 8 + 7], qs, 7); - } - - // Unpack the scales and offsets - for (int i = 0; i < nb; i++) { - const ggml_half * d_m = (const ggml_half *) (y_d + i * dblk_size); - for (int j = 0; j < 8; j++) { - x[i * 8 + j].d = d_m[j * 2 + 0]; - x[i * 8 + j].m = d_m[j * 2 + 1]; - } - } -} - -static void init_row_q4_1x4x2(block_q4_1 * x, int64_t k) { - static const int qk = QK_Q4_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - memset(qs, 0, sizeof(qs)); - - for (int i = 0; i < nb; i++) { - pack_q4_1_quants(&x[i * 8 + 0], qs, 0); - pack_q4_1_quants(&x[i * 8 + 1], qs, 1); - pack_q4_1_quants(&x[i * 8 + 2], qs, 2); - pack_q4_1_quants(&x[i * 8 + 3], qs, 3); - pack_q4_1_quants(&x[i * 8 + 4], qs, 4); - pack_q4_1_quants(&x[i * 8 + 5], qs, 5); - pack_q4_1_quants(&x[i * 8 + 6], qs, 6); - pack_q4_1_quants(&x[i * 8 + 7], qs, 7); - } - - for (int i = 0; i < nb; i++) { - for (int j = 0; j < 8; j++) { - x[i * 8 + j].d = 0; - x[i * 8 + j].m = 0; - } - } -} - -static void repack_q4_1_q4x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4_1-q4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); - - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - init_row_q4_1x4x2((block_q4_1 *) buf_pd, t->ne[0]); - memcpy(buf_pd, src, n_rem_bytes); - repack_row_q4_1x4x2((uint8_t *) buf_rp, (const block_q4_1 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -static void repack_q4x4x2_q4_1(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q4_0x4x2)); - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) - - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q4x4x2-q4_1 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - memset(buf_rp, 0, row_size_rp); // clear-out padded buffer to make sure the tail is all zeros - - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_rp, src, row_size); - unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); - memcpy(dst, buf_pd, row_size); - } - - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - // We still need to read and unpack the entire source row because quantization is block-based. - memcpy(buf_rp, src, row_size); - unpack_row_q4_1x4x2((block_q4_1 *) buf_pd, (const uint8_t *) buf_rp, t->ne[0]); - memcpy(dst, buf_pd, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// ======== Q8x4x2 ==================== -static void dump_block_q8_0(const block_q8_0 * b, int i) { - HEX_VERBOSE("ggml-hex: repack q8_0 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, b->qs[0], b->qs[1], b->qs[2], - b->qs[3], b->qs[28], b->qs[29], b->qs[30], b->qs[31], GGML_FP16_TO_FP32(b->d)); -} - -static void dump_packed_block_q8x4x2(const uint8_t * v, unsigned int i, size_t k) { - static const int qk = QK_Q8_0x4x2; - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk; // int8 - const int qrow_size = k; // int8 (not padded) - - const uint8_t * v_q = v + 0; // quants first - const uint8_t * v_d = v + qrow_size; // then scales - - const uint8_t * q = v_q + i * qblk_size; - const ggml_half * d = (const ggml_half *) (v_d + i * dblk_size); - - HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, - q[0], q[1], q[2], q[3], q[60], q[61], q[62], q[63], q[124], q[125], q[126], q[127], - GGML_FP16_TO_FP32(d[0]), GGML_FP16_TO_FP32(d[1]), GGML_FP16_TO_FP32(d[2]), GGML_FP16_TO_FP32(d[3])); - - HEX_VERBOSE("ggml-hex: repack q8x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", - i + 1, q[128], q[129], q[130], q[131], q[192], q[193], q[194], q[195], q[252], q[253], q[254], q[255], - GGML_FP16_TO_FP32(d[4]), GGML_FP16_TO_FP32(d[5]), GGML_FP16_TO_FP32(d[6]), GGML_FP16_TO_FP32(d[7])); -} - -static void unpack_q8_0_quants(uint8_t * qs, const block_q8_0 * x, unsigned int bi) { - static const int qk = QK8_0; - - for (unsigned int i = 0; i < qk; ++i) { - qs[bi * qk + i] = x->qs[i]; - } -} - -static void pack_q8_0_quants(block_q8_0 * x, const uint8_t * qs, unsigned int bi) { - static const int qk = QK8_0; - - for (unsigned int i = 0; i < qk; ++i) { - x->qs[i] = qs[bi * qk + i]; - } -} - -static void repack_row_q8x4x2(uint8_t * y, const block_q8_0 * x, int64_t k) { - static const int qk = QK_Q8_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk; // int8 - const int qrow_size = k; // int8 (not padded to blocks) - - uint8_t * y_q = y + 0; // quants first - uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q8_0(&x[i * 8 + 0], 0); - dump_block_q8_0(&x[i * 8 + 1], 1); - dump_block_q8_0(&x[i * 8 + 2], 2); - dump_block_q8_0(&x[i * 8 + 3], 3); - dump_block_q8_0(&x[i * 8 + 4], 4); - dump_block_q8_0(&x[i * 8 + 5], 5); - dump_block_q8_0(&x[i * 8 + 6], 6); - dump_block_q8_0(&x[i * 8 + 7], 7); - } - } - - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q8_0x4x2]; // unpacked quants - - unpack_q8_0_quants(qs, &x[i * 8 + 0], 0); - unpack_q8_0_quants(qs, &x[i * 8 + 1], 1); - unpack_q8_0_quants(qs, &x[i * 8 + 2], 2); - unpack_q8_0_quants(qs, &x[i * 8 + 3], 3); - unpack_q8_0_quants(qs, &x[i * 8 + 4], 4); - unpack_q8_0_quants(qs, &x[i * 8 + 5], 5); - unpack_q8_0_quants(qs, &x[i * 8 + 6], 6); - unpack_q8_0_quants(qs, &x[i * 8 + 7], 7); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk; j++) { - q[j] = qs[j]; - } - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Repack the scales - ggml_half * d = (ggml_half *) (y_d + i * dblk_size); - d[0] = x[i * 8 + 0].d; - d[1] = x[i * 8 + 1].d; - d[2] = x[i * 8 + 2].d; - d[3] = x[i * 8 + 3].d; - d[4] = x[i * 8 + 4].d; - d[5] = x[i * 8 + 5].d; - d[6] = x[i * 8 + 6].d; - d[7] = x[i * 8 + 7].d; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q8x4x2(y, i, k); - } - } -} - -static void unpack_row_q8x4x2(block_q8_0 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_Q8_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - const int dblk_size = 8 * 2; // 8x __fp16 - const int qblk_size = qk; // int8 - const int qrow_size = k; // int8 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_d = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_q8x4x2(y, i, k); - } - } - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_Q4_0x4x2]; // unpacked quants - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk; j++) { - qs[j] = q[j]; - } - - pack_q8_0_quants(&x[i * 8 + 0], qs, 0); - pack_q8_0_quants(&x[i * 8 + 1], qs, 1); - pack_q8_0_quants(&x[i * 8 + 2], qs, 2); - pack_q8_0_quants(&x[i * 8 + 3], qs, 3); - pack_q8_0_quants(&x[i * 8 + 4], qs, 4); - pack_q8_0_quants(&x[i * 8 + 5], qs, 5); - pack_q8_0_quants(&x[i * 8 + 6], qs, 6); - pack_q8_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - const ggml_half * d = (const ggml_half *) (y_d + i * dblk_size); - x[i * 8 + 0].d = d[0]; - x[i * 8 + 1].d = d[1]; - x[i * 8 + 2].d = d[2]; - x[i * 8 + 3].d = d[3]; - x[i * 8 + 4].d = d[4]; - x[i * 8 + 5].d = d[5]; - x[i * 8 + 6].d = d[6]; - x[i * 8 + 7].d = d[7]; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_q8_0(&x[i * 8 + 0], 0); - dump_block_q8_0(&x[i * 8 + 1], 1); - dump_block_q8_0(&x[i * 8 + 2], 2); - dump_block_q8_0(&x[i * 8 + 3], 3); - dump_block_q8_0(&x[i * 8 + 4], 4); - dump_block_q8_0(&x[i * 8 + 5], 5); - dump_block_q8_0(&x[i * 8 + 6], 6); - dump_block_q8_0(&x[i * 8 + 7], 7); - } - } -} - -static void init_row_q8x4x2(block_q8_0 * x, int64_t k) { - static const int qk = QK_Q8_0x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - - // Init the quants such that they unpack into zeros - uint8_t qs[QK_Q8_0x4x2]; // unpacked quants - memset(qs, 0, sizeof(qs)); - - for (int i = 0; i < nb; i++) { - pack_q8_0_quants(&x[i * 8 + 0], qs, 0); - pack_q8_0_quants(&x[i * 8 + 1], qs, 1); - pack_q8_0_quants(&x[i * 8 + 2], qs, 2); - pack_q8_0_quants(&x[i * 8 + 3], qs, 3); - pack_q8_0_quants(&x[i * 8 + 4], qs, 4); - pack_q8_0_quants(&x[i * 8 + 5], qs, 5); - pack_q8_0_quants(&x[i * 8 + 6], qs, 6); - pack_q8_0_quants(&x[i * 8 + 7], qs, 7); - } - - // Init the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_Q8_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - x[i * 8 + 0].d = 0; - x[i * 8 + 1].d = 0; - x[i * 8 + 2].d = 0; - x[i * 8 + 3].d = 0; - x[i * 8 + 4].d = 0; - x[i * 8 + 5].d = 0; - x[i * 8 + 6].d = 0; - x[i * 8 + 7].d = 0; - } -} - -// repack q8_0 data into q8x4x2 tensor -static void repack_q8_0_q8x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) - - // Ensure we don't try to read more data than is available in the source buffer 'data' - // or write more than the tensor can hold. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q8_0-q8x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - // re-init the row because we are potentially copying a partial row - init_row_q8x4x2((block_q8_0 *) buf_pd, t->ne[0]); - - // Copy only the remaining bytes from the source. - memcpy(buf_pd, src, n_rem_bytes); - - // Repack the entire buffer - repack_row_q8x4x2((uint8_t *) buf_rp, (const block_q8_0 *) buf_pd, t->ne[0]); - - // Write only the corresponding remaining bytes to the destination tensor. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// repack q8x4x2 tensor into q8_0 data -static void repack_q8x4x2_q8_0(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); - - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_Q8_0x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size quants + scales) - - // Ensure we don't try to copy more data than the tensor actually contains. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; - - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; - - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); - - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-q8x4x2-q8_0 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, size, - t->ne[0], nrows, row_size); - - memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_pd, src, row_size); - unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); - } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - // We still need to read and unpack the entire source row because quantization is block-based. - memcpy(buf_pd, src, row_size); - unpack_row_q8x4x2((block_q8_0 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - - // But we only copy the remaining number of bytes to the destination. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); -} - -// ======== MXFP4x4x2 ==================== -struct x2_mxfp4 { - int v[2]; -}; - -static x2_mxfp4 unpack_mxfp4(uint8_t v) { - x2_mxfp4 x; - x.v[0] = kvalues_mxfp4[(v & 0x0f)]; - x.v[1] = kvalues_mxfp4[(v >> 4)]; - return x; -} - -static void dump_block_mxfp4(const block_mxfp4 * b, int i) { - HEX_VERBOSE("ggml-hex: repack mxfp4 %d: %d %d %d %d ... %d %d %d %d : %.6f\n", i, unpack_mxfp4(b->qs[0]).v[0], - unpack_mxfp4(b->qs[1]).v[0], unpack_mxfp4(b->qs[2]).v[0], unpack_mxfp4(b->qs[3]).v[0], - unpack_mxfp4(b->qs[12]).v[1], unpack_mxfp4(b->qs[13]).v[1], unpack_mxfp4(b->qs[14]).v[1], - unpack_mxfp4(b->qs[15]).v[1], GGML_E8M0_TO_FP32_HALF(b->e)); -} - -static void dump_packed_block_mxfp4x4x2(const uint8_t * v, unsigned int i, size_t k) { - static const int qk = QK_MXFP4x4x2; - const int eblk_size = 8 * 1; // 8x E8M0 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded) - - const uint8_t * v_q = v + 0; // quants first - const uint8_t * v_e = v + qrow_size; // then scales - - const uint8_t * q = v_q + i * qblk_size; - const uint8_t * e = (const uint8_t *) (v_e + i * eblk_size); - - HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", i, - unpack_mxfp4(q[0]).v[0], unpack_mxfp4(q[1]).v[0], unpack_mxfp4(q[2]).v[0], unpack_mxfp4(q[3]).v[0], - unpack_mxfp4(q[60]).v[0], unpack_mxfp4(q[61]).v[0], unpack_mxfp4(q[62]).v[0], unpack_mxfp4(q[63]).v[0], - unpack_mxfp4(q[124]).v[0], unpack_mxfp4(q[125]).v[0], unpack_mxfp4(q[126]).v[0], - unpack_mxfp4(q[127]).v[0], GGML_E8M0_TO_FP32_HALF(e[0]), GGML_E8M0_TO_FP32_HALF(e[1]), - GGML_E8M0_TO_FP32_HALF(e[2]), GGML_E8M0_TO_FP32_HALF(e[3])); - - HEX_VERBOSE("ggml-hex: repack mxfp4x4x2-%d: %d %d %d %d ... %d %d %d %d ... %d %d %d %d : %.6f %.6f %.6f %.6f\n", - i + 1, unpack_mxfp4(q[0]).v[1], unpack_mxfp4(q[1]).v[1], unpack_mxfp4(q[2]).v[1], - unpack_mxfp4(q[3]).v[1], unpack_mxfp4(q[60]).v[1], unpack_mxfp4(q[61]).v[1], unpack_mxfp4(q[62]).v[1], - unpack_mxfp4(q[63]).v[1], unpack_mxfp4(q[124]).v[1], unpack_mxfp4(q[125]).v[1], - unpack_mxfp4(q[126]).v[1], unpack_mxfp4(q[127]).v[1], GGML_E8M0_TO_FP32_HALF(e[4]), - GGML_E8M0_TO_FP32_HALF(e[5]), GGML_E8M0_TO_FP32_HALF(e[6]), GGML_E8M0_TO_FP32_HALF(e[7])); -} - static void unpack_mxfp4_quants(uint8_t * qs, const block_mxfp4 * x, unsigned int bi) { static const int qk = QK_MXFP4; for (unsigned int i = 0; i < qk / 2; ++i) { - const uint8_t x0 = (x->qs[i] & 0x0F); - const uint8_t x1 = (x->qs[i] >> 4); + const int x0 = (x->qs[i] & 0x0F); + const int x1 = (x->qs[i] >> 4); qs[bi * qk + i + 0] = x0; qs[bi * qk + i + qk / 2] = x1; } } static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int bi) { - static const int qk = QK4_0; + static const int qk = QK_MXFP4; for (unsigned int i = 0; i < qk / 2; ++i) { const uint8_t x0 = qs[bi * qk + i + 0]; @@ -1308,299 +435,419 @@ static void pack_mxfp4_quants(block_mxfp4 * x, const uint8_t * qs, unsigned int } } -static void repack_row_mxfp4x4x2(uint8_t * y, const block_mxfp4 * x, int64_t k) { - static const int qk = QK_MXFP4x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers +// repack q4_0 data into q4_0_tiled tensor +static void repack_q4_0_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_q4_0 * src_matrix = (const block_q4_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - const int eblk_size = 8 * 1; // 8x E8M0 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - uint8_t * y_q = y + 0; // quants first - uint8_t * y_e = y + qrow_size; // then scales + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_q4_0 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_mxfp4(&x[i * 8 + 0], 0); - dump_block_mxfp4(&x[i * 8 + 1], 1); - dump_block_mxfp4(&x[i * 8 + 2], 2); - dump_block_mxfp4(&x[i * 8 + 3], 3); - dump_block_mxfp4(&x[i * 8 + 4], 4); - dump_block_mxfp4(&x[i * 8 + 5], 5); - dump_block_mxfp4(&x[i * 8 + 6], 6); - dump_block_mxfp4(&x[i * 8 + 7], 7); - } - } + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; - // Repack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_MXFP4x4x2]; // unpacked quants + uint8_t tile_quants[32][32]; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + unpack_q4_0_quants(tile_quants[row], &src_expert[r * (ne0 / 32) + kt], 0); + } else { + memset(tile_quants[row], 8, 32); + } + } - unpack_mxfp4_quants(qs, &x[i * 8 + 0], 0); - unpack_mxfp4_quants(qs, &x[i * 8 + 1], 1); - unpack_mxfp4_quants(qs, &x[i * 8 + 2], 2); - unpack_mxfp4_quants(qs, &x[i * 8 + 3], 3); - unpack_mxfp4_quants(qs, &x[i * 8 + 4], 4); - unpack_mxfp4_quants(qs, &x[i * 8 + 5], 5); - unpack_mxfp4_quants(qs, &x[i * 8 + 6], 6); - unpack_mxfp4_quants(qs, &x[i * 8 + 7], 7); + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + tile_dst[cp * 32 + row] = (tile_quants[row][2 * cp + 1] << 4) | tile_quants[row][2 * cp]; + } + } - bool partial = (nloe && i == nb-1); - - uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - q[j] = partial ? (qs[j*2+1] << 4) | qs[j*2+0] : (qs[j+128] << 4) | qs[j+000]; - } - } - - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Repack the scales - uint8_t * e = (uint8_t *) (y_e + i * eblk_size); - e[0] = x[i * 8 + 0].e; - e[1] = x[i * 8 + 1].e; - e[2] = x[i * 8 + 2].e; - e[3] = x[i * 8 + 3].e; - e[4] = x[i * 8 + 4].e; - e[5] = x[i * 8 + 5].e; - e[6] = x[i * 8 + 6].e; - e[7] = x[i * 8 + 7].e; - } - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_mxfp4x4x2(y, i, k); - } - } -} - -static void unpack_row_mxfp4x4x2(block_mxfp4 * x, const uint8_t * y, int64_t k) { - static const int qk = QK_MXFP4x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) - const int nloe = k % qk; // leftovers - - const int eblk_size = 8 * 1; // 8x E8M0 - const int qblk_size = qk / 2; // int4 - const int qrow_size = k / 2; // int4 (not padded to blocks) - - const uint8_t * y_q = y + 0; // quants first - const uint8_t * y_e = y + qrow_size; // then scales - - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_packed_block_mxfp4x4x2(y, i, k); - } - } - - // Unpack the quants - for (int i = 0; i < nb; i++) { - uint8_t qs[QK_MXFP4x4x2]; // unpacked quants - - bool partial = (nloe && i == nb-1); - - const uint8_t * q = y_q + (i * qblk_size); - for (int j = 0; j < qk / 2; j++) { - if (partial) { - qs[j*2+0] = q[j] & 0xf; - qs[j*2+1] = q[j] >> 4; - } else { - qs[j+000] = q[j] & 0xf; - qs[j+128] = q[j] >> 4; + ggml_half * scale_dst = (ggml_half *)(tile_dst + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + scale_dst[row] = (r < ne1 && kt < ne0 / 32) ? src_expert[r * (ne0 / 32) + kt].d : 0; + } + } } } - - pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); - pack_mxfp4_quants(&x[i * 8 + 1], qs, 1); - pack_mxfp4_quants(&x[i * 8 + 2], qs, 2); - pack_mxfp4_quants(&x[i * 8 + 3], qs, 3); - pack_mxfp4_quants(&x[i * 8 + 4], qs, 4); - pack_mxfp4_quants(&x[i * 8 + 5], qs, 5); - pack_mxfp4_quants(&x[i * 8 + 6], qs, 6); - pack_mxfp4_quants(&x[i * 8 + 7], qs, 7); } +} - // Repack the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4_0x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - const uint8_t * e = (const uint8_t *) (y_e + i * eblk_size); - x[i * 8 + 0].e = e[0]; - x[i * 8 + 1].e = e[1]; - x[i * 8 + 2].e = e[2]; - x[i * 8 + 3].e = e[3]; - x[i * 8 + 4].e = e[4]; - x[i * 8 + 5].e = e[5]; - x[i * 8 + 6].e = e[6]; - x[i * 8 + 7].e = e[7]; - } +// repack q4_0_tiled tensor into q4_0 data +static void repack_tiled_q4_0(void * data, const ggml_tensor * t, size_t size) { + block_q4_0 * dst_matrix = (block_q4_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - if (opt_verbose > 2) { - for (int i = 0; i < nb; i++) { - dump_block_mxfp4(&x[i * 8 + 0], 0); - dump_block_mxfp4(&x[i * 8 + 1], 1); - dump_block_mxfp4(&x[i * 8 + 2], 2); - dump_block_mxfp4(&x[i * 8 + 3], 3); - dump_block_mxfp4(&x[i * 8 + 4], 4); - dump_block_mxfp4(&x[i * 8 + 5], 5); - dump_block_mxfp4(&x[i * 8 + 6], 6); - dump_block_mxfp4(&x[i * 8 + 7], 7); + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_q4_0 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + uint8_t val = tile_src[cp * 32 + row]; + tile_quants[row][2 * cp + 0] = val & 0x0F; + tile_quants[row][2 * cp + 1] = val >> 4; + } + } + + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + pack_q4_0_quants(&dst_expert[r * (ne0 / 32) + kt], tile_quants[row], 0); + } + } + + const ggml_half * scale_src = (const ggml_half *)(tile_src + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].d = scale_src[row]; + } + } + } + } } } } -static void init_row_mxfp4x4x2(block_mxfp4 * x, int64_t k) { - static const int qk = QK_MXFP4x4x2; - const int nb = (k + qk - 1) / qk; // number of blocks (padded) +// repack q4_1 data into q4_1_tiled tensor +static void repack_q4_1_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_q4_1 * src_matrix = (const block_q4_1 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - // Init the quants such that they unpack into zeros - uint8_t qs[QK_MXFP4x4x2]; // unpacked quants - memset(qs, 0, sizeof(qs)); + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_1; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - for (int i = 0; i < nb; i++) { - pack_mxfp4_quants(&x[i * 8 + 0], qs, 0); - pack_mxfp4_quants(&x[i * 8 + 1], qs, 1); - pack_mxfp4_quants(&x[i * 8 + 2], qs, 2); - pack_mxfp4_quants(&x[i * 8 + 3], qs, 3); - pack_mxfp4_quants(&x[i * 8 + 4], qs, 4); - pack_mxfp4_quants(&x[i * 8 + 5], qs, 5); - pack_mxfp4_quants(&x[i * 8 + 6], qs, 6); - pack_mxfp4_quants(&x[i * 8 + 7], qs, 7); - } + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_q4_1 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; - // Init the scales - // Note: Do not combine with the loop above. For tensor sizes not multiple of 256 (QK_MXFP4x4x2) - // the last block is truncated and overridden by the scales. - for (int i = 0; i < nb; i++) { - // Unpack the scales - x[i * 8 + 0].e = 0; - x[i * 8 + 1].e = 0; - x[i * 8 + 2].e = 0; - x[i * 8 + 3].e = 0; - x[i * 8 + 4].e = 0; - x[i * 8 + 5].e = 0; - x[i * 8 + 6].e = 0; - x[i * 8 + 7].e = 0; + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + unpack_q4_1_quants(tile_quants[row], &src_expert[r * (ne0 / 32) + kt], 0); + } else { + memset(tile_quants[row], 0, 32); + } + } + + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + tile_dst[cp * 32 + row] = (tile_quants[row][2 * cp + 1] << 4) | tile_quants[row][2 * cp]; + } + } + + ggml_half * scale_dst = (ggml_half *)(tile_dst + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + scale_dst[2 * row + 0] = src_expert[r * (ne0 / 32) + kt].d; + scale_dst[2 * row + 1] = src_expert[r * (ne0 / 32) + kt].m; + } else { + scale_dst[2 * row + 0] = 0; + scale_dst[2 * row + 1] = 0; + } + } + } + } + } } } -// repack mxfp4 data into mxfp4x4x2 tensor -static void repack_mxfp4_mxfp4x4x2(ggml_tensor * t, const void * data, size_t size) { - int64_t nrows = ggml_nrows(t); +// repack q4_1_tiled tensor into q4_1 data +static void repack_tiled_q4_1(void * data, const ggml_tensor * t, size_t size) { + block_q4_1 * dst_matrix = (block_q4_1 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q4_1; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - // Ensure we don't try to read more data than is available in the source buffer 'data' - // or write more than the tensor can hold. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_q4_1 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); + uint8_t tile_quants[32][32]; + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + uint8_t val = tile_src[cp * 32 + row]; + tile_quants[row][2 * cp + 0] = val & 0x0F; + tile_quants[row][2 * cp + 1] = val >> 4; + } + } - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + pack_q4_1_quants(&dst_expert[r * (ne0 / 32) + kt], tile_quants[row], 0); + } + } - HEX_VERBOSE("ggml-hex: repack-mxfp4-mxfp4x4x2 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, - size, t->ne[0], nrows, row_size); - - init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); // init padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - memcpy(buf_pd, src, row_size); - repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); + const ggml_half * scale_src = (const ggml_half *)(tile_src + 512); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].d = scale_src[2 * row]; + dst_expert[r * (ne0 / 32) + kt].m = scale_src[2 * row + 1]; + } + } + } + } + } } - - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) data + (i * row_size); - uint8_t * dst = (uint8_t *) t->data + (i * row_size); - - // re-init the row because we are potentially copying a partial row - init_row_mxfp4x4x2((block_mxfp4 *) buf_pd, t->ne[0]); - - // Copy only the remaining bytes from the source. - memcpy(buf_pd, src, n_rem_bytes); - - // Repack the entire buffer (partial data + zero padding). - repack_row_mxfp4x4x2((uint8_t *) buf_rp, (const block_mxfp4 *) buf_pd, t->ne[0]); - - // Write only the corresponding remaining bytes to the destination tensor. - memcpy(dst, buf_rp, n_rem_bytes); - } - - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); } -// repack mxfp4x4x2 tensor into mxfp4 data -static void repack_mxfp4x4x2_mxfp4(void * data, const ggml_tensor * t, size_t size) { - int64_t nrows = ggml_nrows(t); +// repack q8_0 data into q8_0_tiled tensor +static void repack_q8_0_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_q8_0 * src_matrix = (const block_q8_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - size_t row_size = ggml_row_size(t->type, t->ne[0]); - size_t row_size_pd = ggml_row_size(t->type, hex_round_up(t->ne[0], QK_MXFP4x4x2)); // extra elements for the pad - size_t row_size_rp = row_size_pd; // scratch must hold one full padded tile (qblk_size/2 quants + scales) + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q8_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - // Ensure we don't try to copy more data than the tensor actually contains. - const size_t total_tensor_size = (size_t)nrows * row_size; - const size_t n_bytes_to_copy = size < total_tensor_size ? size : total_tensor_size; + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_q8_0 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; - // Calculate how many full rows and how many remaining bytes we need to process. - const int64_t n_full_rows = n_bytes_to_copy / row_size; - const size_t n_rem_bytes = n_bytes_to_copy % row_size; + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; - void * buf_pd = ggml_aligned_malloc(row_size_pd); - GGML_ASSERT(buf_pd != NULL); + for (int cp = 0; cp < 16; cp++) { + int col0 = cp * 2; + int col1 = col0 + 1; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + const block_q8_0 * b = (r < ne1 && kt < ne0 / 32) ? &src_expert[r * (ne0 / 32) + kt] : NULL; + tile_dst[cp * 64 + 2 * row + 0] = b ? b->qs[col0] : 0; + tile_dst[cp * 64 + 2 * row + 1] = b ? b->qs[col1] : 0; + } + } - void * buf_rp = ggml_aligned_malloc(row_size_rp); - GGML_ASSERT(buf_rp != NULL); - - HEX_VERBOSE("ggml-hex: repack-mxfp4x4x2-mxfp4 %s : data %p size %zu dims %ldx%ld row-size %zu\n", t->name, data, - size, t->ne[0], nrows, row_size); - - memset(buf_pd, 0, row_size_pd); // clear-out padded buffer to make sure the tail is all zeros - - // 1. Process all the full rows - for (int64_t i = 0; i < n_full_rows; i++) { - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); - - memcpy(buf_pd, src, row_size); - unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); - memcpy(dst, buf_rp, row_size); + ggml_half * scale_dst = (ggml_half *)(tile_dst + 1024); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + scale_dst[row] = (r < ne1 && kt < ne0 / 32) ? src_expert[r * (ne0 / 32) + kt].d : 0; + } + } + } + } } +} - // 2. Process the final, potentially partial, row - if (n_rem_bytes > 0) { - const int64_t i = n_full_rows; - const uint8_t * src = (const uint8_t *) t->data + (i * row_size); - uint8_t * dst = (uint8_t *) data + (i * row_size); +// repack q8_0_tiled tensor into q8_0 data +static void repack_tiled_q8_0(void * data, const ggml_tensor * t, size_t size) { + block_q8_0 * dst_matrix = (block_q8_0 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); - // We still need to read and unpack the entire source row because the format is block-based. - memcpy(buf_pd, src, row_size); - unpack_row_mxfp4x4x2((block_mxfp4 *) buf_rp, (const uint8_t *) buf_pd, t->ne[0]); + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_Q8_0; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; - // But we only copy the remaining number of bytes to the destination to respect the size limit. - memcpy(dst, buf_rp, n_rem_bytes); + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_q8_0 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + for (int cp = 0; cp < 16; cp++) { + int col0 = cp * 2; + int col1 = col0 + 1; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + block_q8_0 & b = dst_expert[r * (ne0 / 32) + kt]; + b.qs[col0] = tile_src[cp * 64 + 2 * row + 0]; + b.qs[col1] = tile_src[cp * 64 + 2 * row + 1]; + } + } + } + + const ggml_half * scale_src = (const ggml_half *)(tile_src + 1024); + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].d = scale_src[row]; + } + } + } + } + } } +} - ggml_aligned_free(buf_pd, row_size_pd); - ggml_aligned_free(buf_rp, row_size_rp); +// repack mxfp4 data into mxfp4_tiled tensor +static void repack_mxfp4_tiled(ggml_tensor * t, const void * data, size_t size) { + const block_mxfp4 * src_matrix = (const block_mxfp4 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_MXFP4; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + const block_mxfp4 * src_expert = src_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + uint8_t * matrix_dst = (uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + uint8_t * tile_dst = matrix_dst + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + unpack_mxfp4_quants(tile_quants[row], &src_expert[r * (ne0 / 32) + kt], 0); + } else { + memset(tile_quants[row], 0, 32); + } + } + + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + tile_dst[cp * 32 + row] = (tile_quants[row][2 * cp + 1] << 4) | tile_quants[row][2 * cp]; + } + } + + uint8_t * scale_dst = tile_dst + 512; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + scale_dst[row] = (r < ne1 && kt < ne0 / 32) ? src_expert[r * (ne0 / 32) + kt].e : 0; + } + } + } + } + } +} + +// repack mxfp4_tiled tensor into mxfp4 data +static void repack_tiled_mxfp4(void * data, const ggml_tensor * t, size_t size) { + block_mxfp4 * dst_matrix = (block_mxfp4 *) data; + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + int64_t ne0_padded = hex_round_up(ne0, 32); + int64_t ne1_padded = hex_round_up(ne1, 32); + + int n_col_tiles = ne1_padded / 32; + int n_k_tiles = ne0_padded / 32; + const size_t tile_size = HTP_MM_WEIGHT_TILE_SIZE_MXFP4; + const size_t matrix_size = n_col_tiles * n_k_tiles * tile_size; + + for (int i3 = 0; i3 < ne3; i3++) { + for (int i2 = 0; i2 < ne2; i2++) { + block_mxfp4 * dst_expert = dst_matrix + (i3 * ne2 + i2) * (ne1 * (ne0 / 32)); + const uint8_t * matrix_src = (const uint8_t *) t->data + (i3 * ne2 + i2) * matrix_size; + + for (int ct = 0; ct < n_col_tiles; ct++) { + for (int kt = 0; kt < n_k_tiles; kt++) { + const uint8_t * tile_src = matrix_src + (ct * n_k_tiles + kt) * tile_size; + + uint8_t tile_quants[32][32]; + for (int cp = 0; cp < 16; cp++) { + for (int row = 0; row < 32; row++) { + uint8_t val = tile_src[cp * 32 + row]; + tile_quants[row][2 * cp + 0] = val & 0x0F; + tile_quants[row][2 * cp + 1] = val >> 4; + } + } + + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + pack_mxfp4_quants(&dst_expert[r * (ne0 / 32) + kt], tile_quants[row], 0); + } + } + + const uint8_t * scale_src = tile_src + 512; + for (int row = 0; row < 32; row++) { + int64_t r = ct * 32 + row; + if (r < ne1 && kt < ne0 / 32) { + dst_expert[r * (ne0 / 32) + kt].e = scale_src[row]; + } + } + } + } + } + } } static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, @@ -1617,32 +864,32 @@ static void ggml_backend_hexagon_buffer_set_tensor(ggml_backend_buffer_t buffer, case GGML_TYPE_Q4_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4_0_q4x4x2(tensor, data, size); + repack_q4_0_tiled(tensor, data, size); break; case GGML_TYPE_Q4_1: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4_1_q4x4x2(tensor, data, size); + repack_q4_1_tiled(tensor, data, size); break; case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q8_0_q8x4x2(tensor, data, size); + repack_q8_0_tiled(tensor, data, size); break; case GGML_TYPE_IQ4_NL: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); // IQ4_NL has identical block layout to Q4_0 (ggml_half d + uint8_t qs[16]) - repack_q4_0_q4x4x2(tensor, data, size); + repack_q4_0_tiled(tensor, data, size); break; case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_mxfp4_mxfp4x4x2(tensor, data, size); + repack_mxfp4_tiled(tensor, data, size); break; default: @@ -1665,31 +912,31 @@ static void ggml_backend_hexagon_buffer_get_tensor(ggml_backend_buffer_t buffer, case GGML_TYPE_Q4_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4x4x2_q4_0(data, tensor, size); + repack_tiled_q4_0(data, tensor, size); break; case GGML_TYPE_Q4_1: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4x4x2_q4_1(data, tensor, size); + repack_tiled_q4_1(data, tensor, size); break; case GGML_TYPE_Q8_0: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q8x4x2_q8_0(data, tensor, size); + repack_tiled_q8_0(data, tensor, size); break; case GGML_TYPE_IQ4_NL: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_q4x4x2_q4_0(data, tensor, size); + repack_tiled_q4_0(data, tensor, size); break; case GGML_TYPE_MXFP4: GGML_ASSERT(offset == 0); GGML_ASSERT(offset + size <= ggml_nbytes(tensor)); - repack_mxfp4x4x2_mxfp4(data, tensor, size); + repack_tiled_mxfp4(data, tensor, size); break; default: @@ -1767,12 +1014,19 @@ static size_t ggml_backend_hexagon_buffer_type_get_alignment(ggml_backend_buffer } static size_t ggml_backend_hexagon_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * t) { + if (t->type == GGML_TYPE_Q4_0 || t->type == GGML_TYPE_Q4_1 || t->type == GGML_TYPE_Q8_0 || t->type == GGML_TYPE_IQ4_NL || t->type == GGML_TYPE_MXFP4) { + int64_t ne0 = hex_round_up(t->ne[0], 32); + int64_t ne1 = hex_round_up(t->ne[1], 32); + int64_t ne2 = t->ne[2]; + int64_t ne3 = t->ne[3]; + return ggml_row_size(t->type, ne0) * ne1 * ne2 * ne3; + } return ggml_nbytes(t); } static size_t ggml_backend_hexagon_buffer_type_get_max_size(ggml_backend_buffer_type_t buffer_type) { - return opt_mbuf; // typically 1GB per buffer - GGML_UNUSED(buffer_type); + auto * context = static_cast(buffer_type->context); + return context->sess->max_bufsize; } static bool ggml_backend_hexagon_buffer_type_is_host(ggml_backend_buffer_type_t buft) { @@ -1803,6 +1057,17 @@ static ggml_backend_buffer_type_i ggml_backend_hexagon_repack_buffer_type_interf /* .is_host = */ ggml_backend_hexagon_repack_buffer_type_is_host, }; +static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) { + return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment; +} + +static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) { + if (!opt_hostbuf) { + return ggml_backend_buffer_is_hexagon(b); + } + return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; +} + struct ggml_hexagon_opbatch { ggml_hexagon_session* sess; @@ -1883,14 +1148,25 @@ struct ggml_hexagon_opbatch { b_vmem += b.size; - HEX_VERBOSE("ggml-hex: add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); + HEX_VERBOSE("ggml-hex: %s add-buffer #%u : fd %d base %p size %zu : vmem %zu\n", sess->c_name(), bi, b.fd, (void*) sbuf->base, (size_t) b.size, b_vmem); return bi; } bool same_shape(const htp_tensor * h, const ggml_tensor * t) const { - return (h->ne[0] == t->ne[0]) && (h->ne[1] == t->ne[1]) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && - (h->nb[0] == t->nb[0]) && (h->nb[1] == t->nb[1]) && (h->nb[2] == t->nb[2]) && (h->nb[3] == t->nb[3]); + int64_t ne0 = t->ne[0]; + int64_t ne1 = t->ne[1]; + const bool is_repack = ggml_backend_buffer_is_hexagon_repack(t->buffer) && ggml_hexagon_is_repack_type(t->type); + if (is_repack) { + ne0 = hex_round_up(ne0, 32); + ne1 = hex_round_up(ne1, 32); + } + int64_t nb1 = is_repack ? ggml_row_size(t->type, ne0) : t->nb[1]; + int64_t nb2 = is_repack ? nb1 * ne1 : t->nb[2]; + int64_t nb3 = is_repack ? nb2 * t->ne[2] : t->nb[3]; + + return (h->ne[0] == ne0) && (h->ne[1] == ne1) && (h->ne[2] == t->ne[2]) && (h->ne[3] == t->ne[3]) && + (h->nb[0] == t->nb[0]) && (h->nb[1] == nb1) && (h->nb[2] == nb2) && (h->nb[3] == nb3); } // add tensor and return its index @@ -1921,19 +1197,35 @@ struct ggml_hexagon_opbatch { htp_tensor &h = h_tens[ti]; h.bi = add_buffer(sbuf); h.data = t_offset; - h.size = t_size; h.type = t->type; - h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; - h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; + + const bool is_repack = ggml_backend_buffer_is_hexagon_repack(t->buffer) && ggml_hexagon_is_repack_type(t->type); + if (is_repack) { + h.ne[0] = hex_round_up(t->ne[0], 32); + h.ne[1] = hex_round_up(t->ne[1], 32); + h.ne[2] = t->ne[2]; + h.ne[3] = t->ne[3]; + + h.nb[0] = t->nb[0]; + h.nb[1] = ggml_row_size(t->type, h.ne[0]); + h.nb[2] = h.nb[1] * h.ne[1]; + h.nb[3] = h.nb[2] * h.ne[2]; + h.size = h.nb[3] * h.ne[3]; + t_size = h.size; + } else { + h.size = t_size; + h.ne[0] = t->ne[0]; h.ne[1] = t->ne[1]; h.ne[2] = t->ne[2]; h.ne[3] = t->ne[3]; + h.nb[0] = t->nb[0]; h.nb[1] = t->nb[1]; h.nb[2] = t->nb[2]; h.nb[3] = t->nb[3]; + } h.flags = 0; if (ggml_backend_buffer_get_usage(t->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) { h.flags |= HTP_TENSOR_COMPUTE; } - HEX_VERBOSE("ggml-hex: add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", + HEX_VERBOSE("ggml-hex: %s add-tensor #%u %s : bi %d data %p offset %zu size %zu flags 0x%x : %zu:%zu:%zu:%zu\n", sess->c_name(), ti, t->name, h.bi, (void*) t->data, (size_t) t_offset, t_size, h.flags, - (size_t) t->ne[0], (size_t) t->ne[1], (size_t) t->ne[2], (size_t) t->ne[3]); + (size_t) h.ne[0], (size_t) h.ne[1], (size_t) h.ne[2], (size_t) h.ne[3]); return ti; } @@ -1962,7 +1254,9 @@ struct ggml_hexagon_opbatch { for (const auto * src : node.get_inputs()) { fit_tensor(src); } - fit_tensor(node.dst()); + for (const auto * output : node.get_outputs()) { + fit_tensor(output); + } if ((extra_bufs + n_bufs) > n_bufs_max) return false; if ((extra_tens + n_tens) > n_tens_max) return false; @@ -1981,7 +1275,8 @@ struct ggml_hexagon_opbatch { ops[n] = node; htp_op_desc &o = h_ops[n]; - memcpy(&o.params, &node.node->op_params, sizeof(node.node->op_params)); + memcpy(o.params, node.node->op_params, sizeof(node.node->op_params)); + memcpy(o.kernel_params, node.kernel_params, sizeof(o.kernel_params)); o.opcode = node.opcode; o.flags = 0; @@ -1989,13 +1284,17 @@ struct ggml_hexagon_opbatch { o.flags |= HTP_OPFLAGS_SKIP_COMPUTE; } - ggml_hexagon_dump_op_exec(sess->c_name(), node, o.flags); + ggml_hexagon_dump_op_exec(sess->c_name(), ops[n], o.flags); auto inputs = node.get_inputs(); for (unsigned int i=0; i < HTP_OP_MAX_INPUTS; i++) { - o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; + o.src[i] = (i < inputs.size() && inputs[i]) ? add_tensor(inputs[i]) : 0xffff; + } + + auto outputs = node.get_outputs(); + for (unsigned int i=0; i < HTP_OP_MAX_OUTPUTS; i++) { + o.dst[i] = (i < outputs.size() && outputs[i]) ? add_tensor(outputs[i]) : 0xffff; } - o.dst = add_tensor(node.dst()); } }; @@ -2006,14 +1305,14 @@ struct ggml_hexagon_opqueue { using opvec = std::vector; - std::queue done; // completed batch ids - std::vector op_cache; // per batch op cache - std::vector start_usec; // per batch start time + std::queue done; // completed batch ids + std::vector op_cache; // per batch op cache + std::vector start_usec; // per batch start time ggml_hexagon_opqueue(ggml_hexagon_session *sess, size_t batch_size, size_t depth) { size_t n_bufs = HTP_OP_MAX_BUFS; size_t n_ops = batch_size; - size_t n_tensors = n_ops + n_ops * HTP_OP_MAX_INPUTS; + size_t n_tensors = n_ops * HTP_OP_MAX_OUTPUTS + n_ops * HTP_OP_MAX_INPUTS; size_t tr_size = 0; if (opt_profile == 3) { @@ -2200,7 +1499,7 @@ struct ggml_hexagon_opqueue { char evt_str[256] = ""; if (opt_profile == 3) { - sprintf(evt_str, " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]", + snprintf(evt_str, sizeof(evt_str), " evt [%u,%u,%u,%u,%u,%u,%u,%u,%u,%u,%u]", rsp.n_traces[0], rsp.n_traces[1], rsp.n_traces[2], rsp.n_traces[3], rsp.n_traces[4], rsp.n_traces[5], rsp.n_traces[6], rsp.n_traces[7], rsp.n_traces[8], rsp.n_traces[9], rsp.n_traces[10]); @@ -2224,6 +1523,7 @@ void ggml_hexagon_session::flush_pending(bool all) { // Read response packet from queue const uint32_t timeo = opt_oppoll ? 0 : DSPQUEUE_TIMEOUT; + int err = dspqueue_read(this->queue, &flags, 1, &n_dbufs, &dbuf, sizeof(rsp), &rsp_size, (uint8_t *) &rsp, timeo); if (err == AEE_EEXPIRED) { continue; @@ -2404,6 +1704,31 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { this->valid_handle = true; + // Query HW info and resolve session options + this->max_bufsize = opt_mbuf; + { + unsigned int hw_n_threads = 0; + unsigned int hw_n_hvx = 0; + unsigned int hw_n_hmx = 0; + unsigned long long hw_vtcm_size = 0; + int hw_err = htp_iface_hwinfo(this->handle, &hw_n_threads, &hw_n_hvx, &hw_n_hmx, &hw_vtcm_size); + if (hw_err == 0) { + this->n_threads = opt_nhvx > 0 ? (uint32_t)opt_nhvx : (uint32_t)hw_n_threads; + this->n_hvx = opt_nhvx > 0 ? (uint32_t)opt_nhvx : (uint32_t)hw_n_hvx; + this->n_hmx = (opt_nhmx != 0) ? (uint32_t)hw_n_hmx : 0; + this->vtcm_size = (uint64_t)hw_vtcm_size; + GGML_LOG_INFO("ggml-hex: %s hwinfo: threads %u, hvx %u, hmx %u, vtcm %llu MB\n", + this->c_name(), this->n_threads, this->n_hvx, this->n_hmx, + (unsigned long long)(this->vtcm_size / (1024 * 1024))); + } else { + GGML_LOG_WARN("ggml-hex: %s failed to query hwinfo (0x%x), using defaults\n", this->c_name(), hw_err); + this->n_threads = opt_nhvx > 0 ? (uint32_t)opt_nhvx : 8; + this->n_hvx = opt_nhvx > 0 ? (uint32_t)opt_nhvx : 8; + this->n_hmx = (opt_nhmx != 0) ? 1 : 0; + this->vtcm_size = 8 * 1024 * 1024; + } + } + // Enable FastRPC QoS mode { struct remote_rpc_control_latency l; @@ -2468,11 +1793,12 @@ void ggml_hexagon_session::allocate(int dev_id) noexcept(false) { opt_vmem = ggml_hexagon_measure_max_vmem(this); GGML_LOG_INFO("ggml-hex: %s measured max vmem %zu\n", this->c_name(), opt_vmem); } + this->max_vmem = opt_vmem; - this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, opt_vmem); + this->op_batch = new ggml_hexagon_opbatch(this, opt_opbatch, this->max_vmem); // Start dspqueue/opbatch processing - err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_use_hmx, opt_vmem); + err = htp_iface_start(this->handle, dev_id, this->queue_id, opt_nhvx, opt_nhmx, this->max_vmem); if (err != 0) { GGML_LOG_ERROR("ggml-hex: %s failed to start session: 0x%08x\n", this->c_name(), (unsigned) err); throw std::runtime_error("ggml-hex: iface start failed (see log for details)"); @@ -2553,16 +1879,6 @@ ggml_hexagon_session::~ggml_hexagon_session() noexcept(true) { // ** backend interface -static bool ggml_backend_buffer_is_hexagon(const struct ggml_backend_buffer * b) { - return b->buft->iface.get_alignment == ggml_backend_hexagon_buffer_type_get_alignment; -} - -static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backend_buffer * b) { - if (!opt_hostbuf) { - return ggml_backend_buffer_is_hexagon(b); - } - return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; -} static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; @@ -2653,6 +1969,640 @@ static bool ggml_hexagon_supported_gated_delta_net(const struct ggml_hexagon_ses return true; } +static bool ggml_hexagon_matmul_is_hmx_eligible( + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + int ne01_padded, + bool is_matmul_id, + bool is_batched +) { + const int ne00 = src0->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int wtype = src0->type; + + // HMX weight tile requires N to be 32-aligned. + if (ne01_padded % 32 != 0) { + return false; + } + + // HMX supports F16, F32, and repack quantized types. + if (!ggml_hexagon_is_hmx_weight_type((ggml_type) wtype)) { + return false; + } + + // HMX paths require K aligned to 32. + if (ne00 % 32 != 0) { + return false; + } + + // Quantized HMX kernels only handle flat 2D matmul (or matmul_id wrapping flat 2D matmuls). + if (!is_matmul_id && is_batched && wtype != GGML_TYPE_F16) { + return false; + } + + // HMX assumes contiguous row-major layout. + if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { + return false; + } + + // M alignment: Use HMX when M > HTP_MM_HMX_MIN_NROWS + const int m = is_matmul_id ? ne12 : ne11; + if (m <= HTP_MM_HMX_MIN_NROWS) { + return false; + } + + return true; +} + +static bool ggml_hexagon_precompute_hmx_mm_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + int wtype, + int ne00_padded, + int ne01_padded, + int ne02, + int ne11, + int ne12, + int ne11_padded, + bool is_matmul_id, + bool is_batched, + size_t vtcm_budget, + struct htp_mm_kernel_params * kparams +) { + const int aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + const bool pipeline = is_matmul_id ? false : htp_mm_hmx_pipeline(ne11); + const int n_threads = (int)sess->n_threads; + const int ne10 = src1->ne[0]; + + const bool is_batched_val = is_matmul_id ? false : is_batched; + const int group_size = (ne02 > 0 ? ne12 / ne02 : 1); + + size_t m_chunk = 0; + size_t n_chunk = 0; + size_t vtcm_size = 0; + bool use_grouped = false; + int act_threads_selected = 0; + + if (is_batched_val && wtype == GGML_TYPE_F16 && group_size > 1) { + // Try grouped path first + const bool use_dma_activation = (src1->nb[1]/sizeof(float) > (size_t)ne00_padded); + size_t best_mblocks = SIZE_MAX; + int best_act_threads = 0; + size_t best_m_chunk = 0; + size_t best_n_chunk = 0; + size_t best_vtcm_size = 0; + + int act_threads = n_threads; + while (act_threads >= 1) { + const size_t f32_scratch_size = use_dma_activation ? hex_align_up(act_threads * HTP_MM_DMA_ACT_MULTIPLIER * ne00_padded * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0; + size_t group_overhead = 256 + f32_scratch_size; + size_t group_size_per_n, group_size_per_m, group_size_per_mn; + htp_mm_hmx_get_batched_chunk_costs(ne00_padded, group_size, &group_size_per_n, &group_size_per_m, &group_size_per_mn); + + size_t m_chunk_candidate = 0; + size_t n_chunk_candidate = 0; + size_t vtcm_size_candidate = 0; + + if (htp_mm_hmx_compute_chunks(vtcm_budget, group_overhead, group_size_per_n, group_size_per_m, group_size_per_mn, hex_align_up(ne11, 32), ne01_padded, + (size_t) ne01_padded * HTP_MM_HMX_COST_W_DEQUANT, (size_t) ne11 * HTP_MM_HMX_COST_A_CONVERT, + &m_chunk_candidate, &n_chunk_candidate, &vtcm_size_candidate) == 0) { + size_t exact_size = htp_mm_hmx_get_batched_vtcm_size(wtype, ne00_padded, m_chunk_candidate, n_chunk_candidate, group_size, use_dma_activation, pipeline, act_threads); + if (exact_size <= vtcm_budget) { + size_t mblocks = ((size_t) ne11 + m_chunk_candidate - 1) / m_chunk_candidate; + if (mblocks < best_mblocks || (mblocks == best_mblocks && act_threads > best_act_threads)) { + best_mblocks = mblocks; + best_act_threads = act_threads; + best_m_chunk = m_chunk_candidate; + best_n_chunk = n_chunk_candidate; + best_vtcm_size = exact_size; + } + } + } + if (act_threads == 1) { + act_threads = 0; + } else { + act_threads /= 2; + } + } + + if (best_act_threads > 0) { + m_chunk = best_m_chunk; + n_chunk = best_n_chunk; + vtcm_size = best_vtcm_size; + act_threads_selected = best_act_threads; + use_grouped = true; + } + } + + if (!use_grouped) { + // Fallback to simple 2D path (group_size = 1) + size_t best_mblocks = SIZE_MAX; + int best_act_threads = 0; + size_t best_m_chunk = 0; + size_t best_n_chunk = 0; + size_t best_vtcm_size = 0; + + // For MUL_MAT_ID the kernel runs one 2D matmul per expert, with M equal to the number of rows routed to that expert. + // A single expert can receive up to all routed rows (dst->ne[1]*dst->ne[2] = n_expert_used*n_tokens), so size the chunk + // search for that upper bound rather than ne12 (token positions only). + // We recompute m_chunk per expert against the actual count in the NPU kernel. + const int m_id_rows = (int) ((size_t) dst->ne[1] * dst->ne[2]); + const int m_for_chunks = is_matmul_id ? hex_align_up(m_id_rows, 32) : ne11_padded; + const int m_for_cost = is_matmul_id ? m_id_rows : ne11; + + int act_threads = n_threads; + while (act_threads >= 1) { + const size_t act_f32_size = is_matmul_id ? 0 : hex_align_up(act_threads * HTP_MM_DMA_ACT_MULTIPLIER * ne00_padded * sizeof(float), HTP_MM_HMX_TILE_SIZE); + size_t simple_2d_overhead = 256 + act_f32_size; + size_t simple_2d_size_per_n, simple_2d_size_per_m, simple_2d_size_per_mn; + htp_mm_hmx_get_2d_chunk_costs(wtype, ne00_padded, pipeline, aligned_tile_size, &simple_2d_size_per_n, &simple_2d_size_per_m, &simple_2d_size_per_mn); + + size_t m_chunk_candidate = 0; + size_t n_chunk_candidate = 0; + size_t vtcm_size_candidate = 0; + + if (htp_mm_hmx_compute_chunks(vtcm_budget, simple_2d_overhead, simple_2d_size_per_n, simple_2d_size_per_m, simple_2d_size_per_mn, m_for_chunks, ne01_padded, + (size_t) ne01_padded * HTP_MM_HMX_COST_W_DEQUANT, (size_t) m_for_cost * HTP_MM_HMX_COST_A_CONVERT, + &m_chunk_candidate, &n_chunk_candidate, &vtcm_size_candidate) == 0) { + size_t exact_size = htp_mm_hmx_get_2d_vtcm_size(wtype, ne00_padded, m_chunk_candidate, n_chunk_candidate, pipeline, is_matmul_id ? 0 : act_threads, aligned_tile_size); + if (exact_size <= vtcm_budget) { + size_t mblocks = ((size_t) m_for_cost + m_chunk_candidate - 1) / m_chunk_candidate; + if (mblocks < best_mblocks || (mblocks == best_mblocks && act_threads > best_act_threads)) { + best_mblocks = mblocks; + best_act_threads = act_threads; + best_m_chunk = m_chunk_candidate; + best_n_chunk = n_chunk_candidate; + best_vtcm_size = exact_size; + } + } + } + if (act_threads == 1) { + act_threads = 0; + } else { + act_threads /= 2; + } + } + + if (best_act_threads > 0) { + m_chunk = best_m_chunk; + n_chunk = best_n_chunk; + vtcm_size = best_vtcm_size; + act_threads_selected = best_act_threads; + } else { + return false; + } + } + + kparams->n_hmx = 1; + kparams->pipeline = pipeline ? 1 : 0; + kparams->m_chunk = m_chunk; + kparams->n_chunk = n_chunk; + kparams->n_threads = n_threads; + kparams->n_act_threads = act_threads_selected; + kparams->tile_size = htp_mm_get_weight_tile_size(wtype); + kparams->aligned_tile_size = aligned_tile_size; + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + kparams->vtcm_size = vtcm_size; + kparams->vtcm_src0_size = 0; + kparams->vtcm_src1_size = 0; + kparams->vtcm_dst_size = 0; + + if (is_batched && !is_matmul_id) { + kparams->kernel_type = HTP_MM_KERNEL_HMX_F16_BATCHED; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HMX_2D; + } + return true; +} + +static void ggml_hexagon_precompute_hvx_mm_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + int wtype, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + bool is_matmul_id, + size_t vtcm_budget, + struct htp_mm_kernel_params * kparams +) { + kparams->n_hmx = 0; + + const bool is_quant = (wtype != GGML_TYPE_F16 && wtype != GGML_TYPE_F32); + const int src1_nrows = ne11 * ne12 * ne13; + + if (is_quant) { + // Quantized HVX + kparams->tile_size = htp_mm_get_weight_tile_size(wtype); + kparams->aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + + const bool k_align = (ne10 % 32 == 0); + + if (is_matmul_id) { + kparams->kernel_type = (src1_nrows < (int) sess->n_threads) ? HTP_MM_KERNEL_HVX_QUANT_BLOCK : HTP_MM_KERNEL_HVX_QUANT_ROW; + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0; + uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16; + uint32_t best_n_prefetch = 2; + size_t total_size = 0; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + total_size = htp_mm_hvx_id_get_vtcm_sizes( + wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], d, + &vtcm_src0_size, &vtcm_src1_size + ); + if (total_size <= vtcm_budget) { + best_n_prefetch = d; + break; + } + } + if (best_n_prefetch == 2 && total_size > vtcm_budget) { + total_size = htp_mm_hvx_id_get_vtcm_sizes( + wtype, ne10, src1_nrows, sess->n_threads, src0->nb[1], 2, + &vtcm_src0_size, &vtcm_src1_size + ); + } + kparams->n_prefetch = best_n_prefetch; + kparams->vtcm_size = total_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = 0; + } else { + bool try_tiled = (k_align && opt_mm_select >= 2); + if (try_tiled) { + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + if (src1_nrows < (int)sess->n_threads) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_BLOCK; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW; + } + + uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16; + uint32_t best_n_prefetch = 2; + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t total_size = 0; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + total_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], d, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + if (total_size <= vtcm_budget) { + best_n_prefetch = d; + break; + } + } + if (best_n_prefetch == 2 && total_size > vtcm_budget) { + total_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 2, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + } + + kparams->n_prefetch = best_n_prefetch; + + if (total_size <= vtcm_budget) { + kparams->vtcm_size = total_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + goto done_quant; + } + HEX_VERBOSE("ggml-hex: %s HVX tiled path VTCM size needed (%zu) > budget (%zu), falling back to HVX flat\n", sess->name.c_str(), total_size, vtcm_budget); + } + + // Flat HVX fallback + { + kparams->src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT; + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t total_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + + kparams->n_prefetch = 16; + kparams->vtcm_size = total_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + } + } + + done_quant:; + } else if (wtype == GGML_TYPE_F16) { + // F16 HVX + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = ggml_is_permuted(src0) || ggml_is_permuted(src1); + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t vtcm_size = htp_mm_hvx_get_vtcm_sizes( + HTP_MM_KERNEL_HVX_F16_F16_VTCM, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + + if (!is_batched && !is_permuted && vtcm_size <= vtcm_budget) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F16_F16_VTCM; + kparams->src1_row_size = hex_round_up(ne10 * 2, 128); + kparams->vtcm_size = vtcm_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } else { + if (src1->type == GGML_TYPE_F32) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F16_F32_DDR; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F16_F16_DDR; + } + kparams->src1_row_size = src1->nb[1]; + size_t ddr_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + kparams->vtcm_size = ddr_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } + } else { + // F32 HVX + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = ggml_is_permuted(src0) || ggml_is_permuted(src1); + + size_t vtcm_src0_size = 0, vtcm_src1_size = 0, vtcm_dst_size = 0; + size_t vtcm_size = htp_mm_hvx_get_vtcm_sizes( + HTP_MM_KERNEL_HVX_F32_F32_VTCM, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + + if (!is_batched && !is_permuted && vtcm_size <= vtcm_budget) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F32_F32_VTCM; + kparams->src1_row_size = hex_round_up(ne10 * 4, 128); + kparams->vtcm_size = vtcm_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_F32_F32_DDR; + kparams->src1_row_size = src1->nb[1]; + size_t ddr_size = htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, wtype, ne10, src1_nrows, sess->n_threads, + dst->nb[1], src0->nb[1], src1->nb[1], 16, &vtcm_src0_size, &vtcm_src1_size, &vtcm_dst_size + ); + kparams->vtcm_size = ddr_size; + kparams->vtcm_src0_size = vtcm_src0_size; + kparams->vtcm_src1_size = vtcm_src1_size; + kparams->vtcm_dst_size = vtcm_dst_size; + kparams->n_prefetch = 16; + } + } +} + +static void ggml_hexagon_precompute_matmul_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, + const struct ggml_tensor * src1, + const struct ggml_tensor * dst, + struct htp_mm_kernel_params * kparams +) { + memset(kparams, 0, sizeof(*kparams)); + + const int ne00 = src0->ne[0]; + const int ne01 = src0->ne[1]; + const int ne02 = src0->ne[2]; + const int ne03 = src0->ne[3]; + + const int ne10 = src1->ne[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const int ne13 = src1->ne[3]; + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + const int ne00_padded = is_repack ? hex_round_up(ne00, 32) : ne00; + const int ne01_padded = is_repack ? hex_round_up(ne01, 32) : ne01; + const int ne11_padded = hex_round_up(ne11, 32); + + const bool is_matmul_id = (dst->op == GGML_OP_MUL_MAT_ID); + const bool is_batched = (ne02 * ne03 > 1 || ne12 * ne13 > 1); + + const size_t vtcm_budget = sess->vtcm_size; + + // Check HMX eligibility and try precomputing HMX parameters + bool hmx_enabled = (sess->n_hmx > 0) && (opt_mm_select >= 3); + if (hmx_enabled && ggml_hexagon_matmul_is_hmx_eligible(src0, src1, dst, ne01_padded, is_matmul_id, is_batched)) { + if (ggml_hexagon_precompute_hmx_mm_params(sess, src0, src1, dst, wtype, ne00_padded, ne01_padded, ne02, ne11, ne12, ne11_padded, is_matmul_id, is_batched, vtcm_budget, kparams)) { + goto finalize; + } + } + + // Fallback to HVX parameter computation + ggml_hexagon_precompute_hvx_mm_params(sess, src0, src1, dst, wtype, ne02, ne03, ne10, ne11, ne12, ne13, is_matmul_id, vtcm_budget, kparams); + +finalize: + kparams->div_ne12_ne1 = init_fastdiv_values(ne12 * ne11); + kparams->div_ne1 = init_fastdiv_values(ne11); + kparams->div_r2 = init_fastdiv_values(ne02 > 0 ? ne12 / ne02 : 1); + kparams->div_r3 = init_fastdiv_values(ne03 > 0 ? ne13 / ne03 : 1); + kparams->div_ne11 = init_fastdiv_values(ne11); +} + +static void ggml_hexagon_precompute_fused_qkv_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, // Wk + const struct ggml_tensor * src1, // x + struct htp_mm_kernel_params * kparams +) { + memset(kparams, 0, sizeof(*kparams)); + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + + const int ne10 = src1->ne[0]; + const int src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + const size_t src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + size_t src0_sz_per_thread = 0; + size_t src2_sz_per_thread = 0; + size_t src3_sz_per_thread = 0; + uint32_t best_n_prefetch = 16; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float)); + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + size_t src1_sz = src1_sz_per_thread; + + const uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16; + best_n_prefetch = 2; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + size_t src0_sz = repacked_vtcm_size * sess->n_threads; + size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads; + size_t src3_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads; + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz; + + if (tiled_vtcm_size <= sess->vtcm_size) { + best_n_prefetch = d; + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(d * tile_row_size, 128); + src3_sz_per_thread = hex_round_up(d * tile_row_size, 128); + break; + } + } + if (best_n_prefetch == 2 && src0_sz_per_thread == 0) { + size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128); + src3_sz_per_thread = hex_round_up(2 * tile_row_size, 128); + } + } else { + best_n_prefetch = 16; + src0_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + src2_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + src3_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + } + + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + + size_t src0_sz = src0_sz_per_thread * sess->n_threads; + size_t src1_sz = src1_sz_per_thread; + size_t src2_sz = src2_sz_per_thread * sess->n_threads; + size_t src3_sz = src3_sz_per_thread * sess->n_threads; + + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz + src3_sz; + bool try_tiled = (opt_mm_select >= 2); + if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW; + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_src3_size = src3_sz; + kparams->vtcm_size = tiled_vtcm_size; + kparams->n_prefetch = best_n_prefetch; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT; + size_t flat_src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + size_t flat_src1_sz = hex_round_up(flat_src1_row_size * src1_nrows, 128); + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = flat_src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_src3_size = src3_sz; + kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz + src3_sz; + kparams->n_prefetch = best_n_prefetch; + } +} + +static void ggml_hexagon_precompute_fused_ffn_params( + const struct ggml_hexagon_session * sess, + const struct ggml_tensor * src0, // Wgate + const struct ggml_tensor * src1, // y + struct htp_mm_kernel_params * kparams +) { + memset(kparams, 0, sizeof(*kparams)); + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + + const int ne10 = src1->ne[0]; + const int src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + const size_t src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + size_t src0_sz_per_thread = 0; + size_t src2_sz_per_thread = 0; + uint32_t best_n_prefetch = 16; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = hex_round_up(ne10, 32) / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float)); + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + size_t src1_sz = src1_sz_per_thread; + + const uint32_t max_prefetch = (src1_nrows > HTP_MM_HMX_MIN_NROWS) ? 2 : 16; + best_n_prefetch = 2; + for (uint32_t d = max_prefetch; d >= 2; d /= 2) { + size_t repacked_vtcm_size = hex_round_up(d * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + size_t src0_sz = repacked_vtcm_size * sess->n_threads; + size_t src2_sz = hex_round_up(d * tile_row_size, 128) * sess->n_threads; + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz; + + if (tiled_vtcm_size <= sess->vtcm_size) { + best_n_prefetch = d; + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(d * tile_row_size, 128); + break; + } + } + if (best_n_prefetch == 2 && src0_sz_per_thread == 0) { + size_t repacked_vtcm_size = hex_round_up(2 * tile_row_size, 128); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + src0_sz_per_thread = repacked_vtcm_size; + src2_sz_per_thread = hex_round_up(2 * tile_row_size, 128); + } + } else { + best_n_prefetch = 16; + src0_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + src2_sz_per_thread = hex_round_up(best_n_prefetch * src0_row_size_padded, 128); + } + + size_t src1_sz_per_thread = hex_round_up(src1_row_size * src1_nrows, 128); + + size_t src0_sz = src0_sz_per_thread * sess->n_threads; + size_t src1_sz = src1_sz_per_thread; + size_t src2_sz = src2_sz_per_thread * sess->n_threads; + + size_t tiled_vtcm_size = src0_sz + src1_sz + src2_sz; + bool try_tiled = (opt_mm_select >= 2); + if (try_tiled && tiled_vtcm_size <= sess->vtcm_size) { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW; + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_size = tiled_vtcm_size; + kparams->n_prefetch = best_n_prefetch; + } else { + kparams->kernel_type = HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT; + size_t flat_src1_row_size = (wtype == GGML_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + size_t flat_src1_sz = hex_round_up(flat_src1_row_size * src1_nrows, 128); + kparams->vtcm_src0_size = src0_sz; + kparams->vtcm_src1_size = flat_src1_sz; + kparams->vtcm_src2_size = src2_sz; + kparams->vtcm_size = src0_sz + flat_src1_sz + src2_sz; + kparams->n_prefetch = best_n_prefetch; + } +} + static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; @@ -2675,12 +2625,13 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; } - if (ggml_nrows(src0) > 16 * 1024) { - return false; // typically the lm-head which would be too large for VTCM + // hardcoded limit to refuse the lm-head for now + if (src0->ne[1] > 32768) { + return false; } - if (ggml_nrows(src1) > 1024 || src1->ne[2] != 1 || src1->ne[3] != 1) { - return false; // no huge batches or broadcasting (for now) + if (src1->ne[2] != 1 || src1->ne[3] != 1) { + return false; // no broadcasting (for now) } // src0 (weights) must be repacked @@ -2691,16 +2642,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s case GGML_TYPE_F16: if (src0->nb[1] < src0->nb[0]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F16 src0 not supported\n"); return false; } if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); return false; } - if (ggml_nrows(src1) > 1024) { - return false; // no huge batches (for now) - } break; case GGML_TYPE_F32: @@ -2708,22 +2654,24 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; } if (src0->nb[1] < src0->nb[0]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: permuted F32 src0 not supported\n"); return false; } if (src1->ne[2] < src0->ne[2] || src1->ne[3] < src0->ne[3]) { - GGML_LOG_DEBUG("ggml_hexagon_supported_mul_mat: src1 broadcasting not supported\n"); return false; } - if (ggml_nrows(src1) > 1024) { - return false; // no huge batches (for now) - } break; default: return false; } + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_matmul_params(sess, src0, src1, dst, &kparams); + if ((size_t)kparams.vtcm_size > sess->vtcm_size) { + HEX_VERBOSE("ggml-hex: %s supported MUL_MAT VTCM size needed (%d) > budget (%zu)\n", sess->c_name(), kparams.vtcm_size, sess->vtcm_size); + return false; + } + return true; } @@ -2757,6 +2705,13 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session return false; } + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_matmul_params(sess, src0, src1, dst, &kparams); + if ((size_t)kparams.vtcm_size > sess->vtcm_size) { + HEX_VERBOSE("ggml-hex: %s supported MUL_MAT_ID VTCM size needed (%d) > budget (%zu)\n", sess->c_name(), kparams.vtcm_size, sess->vtcm_size); + return false; + } + return true; } @@ -3288,47 +3243,172 @@ static inline bool op_is_compute(ggml_tensor *node) return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE); } +static bool is_hmx_eligible(const ggml_tensor * t) { + if (opt_nhmx == 0) { return false; } + + const ggml_tensor * src0 = t->src[0]; + const ggml_tensor * src1 = t->src[1]; + + const int wtype = src0->type; + const bool is_repack = ggml_hexagon_is_repack_type((ggml_type) wtype); + const bool is_matmul_id = (t->op == GGML_OP_MUL_MAT_ID); + const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); + + const int ne01_padded = is_repack ? hex_round_up(src0->ne[1], 32) : src0->ne[1]; + + return ggml_hexagon_matmul_is_hmx_eligible(src0, src1, t, ne01_padded, is_matmul_id, is_batched); +} + +static bool is_mergeable_mul_mat(const ggml_tensor * t) { + if (!t || t->op != GGML_OP_MUL_MAT) return false; + if (t->src[1]->type != GGML_TYPE_F32) return false; + return ggml_is_quantized(t->src[0]->type) && !is_hmx_eligible(t); +} + +static bool is_mergeable_mul_mat_pair(const ggml_tensor * n1, const ggml_tensor * n2) { + if (!is_mergeable_mul_mat(n1) || !is_mergeable_mul_mat(n2)) { + return false; + } + if (n1->src[1] != n2->src[1]) { + return false; + } + if (n1->src[0]->ne[0] != n2->src[0]->ne[0] || + n1->src[0]->ne[1] != n2->src[0]->ne[1]) { + return false; + } + if (n1->src[0]->type != n2->src[0]->type) { + return false; + } + return true; +} + +static bool is_qkv_mergeable(const ggml_tensor * n_q, const ggml_tensor * n_k, const ggml_tensor * n_v) { + if (!is_mergeable_mul_mat(n_q) || !is_mergeable_mul_mat(n_k) || !is_mergeable_mul_mat(n_v)) { + return false; + } + if (n_q->src[1] != n_k->src[1] || n_q->src[1] != n_v->src[1]) { + return false; + } + if (n_q->src[0]->type != n_k->src[0]->type || n_q->src[0]->type != n_v->src[0]->type) { + return false; + } + if (n_k->src[0]->ne[0] != n_v->src[0]->ne[0] || + n_k->src[0]->ne[1] != n_v->src[0]->ne[1]) { + return false; + } + if (n_q->src[0]->ne[0] != n_k->src[0]->ne[0]) { + return false; + } + return true; +} + +static bool try_fuse_node(const ggml_hexagon_session * sess, const ggml_cgraph * graph, int & i, std::vector & nodes) { + if (!opt_opfusion) { + return false; + } + + ggml_tensor * n = graph->nodes[i]; + ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; + + if (n->op == GGML_OP_RMS_NORM && next_node) { + if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + htp_opnode node(n, {}, HTP_OP_RMS_NORM_MUL); + node.add_fused(next_node); + nodes.push_back(std::move(node)); + i++; // skip the fused MUL node + return true; + } + } + + if (is_mergeable_mul_mat(n)) { + ggml_tensor * n1 = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; + ggml_tensor * n2 = (i + 2 < graph->n_nodes) ? graph->nodes[i + 2] : nullptr; + if (is_qkv_mergeable(n, n1, n2)) { + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_fused_qkv_params(sess, n1->src[0], n1->src[1], &kparams); + if ((size_t)kparams.vtcm_size <= sess->vtcm_size) { + // Reorder to KVQ: K (n1), V (n2), Q (n) + htp_opnode node(n1, {}, HTP_OP_MUL_MAT_QKV); + node.add_fused(n2, true); + node.add_fused(n, true); + memcpy(node.kernel_params, &kparams, sizeof(kparams)); + nodes.push_back(std::move(node)); + i += 2; + return true; + } else { + HEX_VERBOSE("ggml-hex: skip QKV fusion because VTCM needed (%d) > budget (%zu)\n", + kparams.vtcm_size, sess->vtcm_size); + } + } + if (is_mergeable_mul_mat_pair(n, n1)) { + struct htp_mm_kernel_params kparams; + ggml_hexagon_precompute_fused_ffn_params(sess, n->src[0], n->src[1], &kparams); + if ((size_t)kparams.vtcm_size <= sess->vtcm_size) { + htp_opnode node(n, {}, HTP_OP_MUL_MAT_FFN); + node.add_fused(n1, true); + memcpy(node.kernel_params, &kparams, sizeof(kparams)); + nodes.push_back(std::move(node)); + i += 1; + return true; + } else { + HEX_VERBOSE("ggml-hex: skip FFN fusion because VTCM needed (%d) > budget (%zu)\n", + kparams.vtcm_size, sess->vtcm_size); + } + } + } + + return false; +} + static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, ggml_cgraph * graph) { auto sess = static_cast(backend->context); HEX_VERBOSE("ggml-hex: %s graph-compute n_nodes %d\n", sess->c_name(), graph->n_nodes); - std::vector nodes; - nodes.reserve(graph->n_nodes); + const std::vector * nodes_ptr = nullptr; + std::vector computed_nodes; - // Fusion - for (int i = 0; i < graph->n_nodes; ++i) { - ggml_tensor * n = graph->nodes[i]; - if (!op_is_compute(n)) { - continue; - } + // Check for cache hit + bool cache_hit = (graph->uid != 0 && sess->cached_graph.uid == graph->uid); + if (cache_hit) { + nodes_ptr = &sess->cached_graph.htp_nodes; + } else { + computed_nodes.reserve(graph->n_nodes); - ggml_tensor * next_node = (i + 1 < graph->n_nodes) ? graph->nodes[i + 1] : nullptr; - - htp_opnode node = { - /*.node =*/ n, - /*.fused =*/ {}, - /*.opcode =*/ HTP_OP_INVALID - }; - - if (n->op == GGML_OP_RMS_NORM && next_node) { - if (next_node->op == GGML_OP_MUL && op_is_compute(next_node) && ggml_can_fuse(graph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { - node.add_fused(next_node); - node.opcode = HTP_OP_RMS_NORM_MUL; - i++; // skip the fused MUL node + // Fuse and finalize + for (int i = 0; i < graph->n_nodes; ++i) { + ggml_tensor * n = graph->nodes[i]; + if (!op_is_compute(n)) { + continue; } - } - if (node.opcode == HTP_OP_INVALID) { + if (try_fuse_node(sess, graph, i, computed_nodes)) { + continue; + } + + htp_opnode node(n, {}, HTP_OP_INVALID); node.opcode = op_remap_to_htp(n); + if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID) { + ggml_hexagon_precompute_matmul_params(sess, + node.node->src[0], node.node->src[1], node.node, + (struct htp_mm_kernel_params *)node.kernel_params + ); + } + computed_nodes.push_back(std::move(node)); } - nodes.push_back(std::move(node)); + if (graph->uid != 0) { + sess->cached_graph.uid = graph->uid; + sess->cached_graph.htp_nodes = std::move(computed_nodes); + nodes_ptr = &sess->cached_graph.htp_nodes; + } else { + nodes_ptr = &computed_nodes; + } } // Queue and execute if (opt_opstage & HTP_OPSTAGE_QUEUE) { - for (const auto & node : nodes) { + for (const auto & node : *nodes_ptr) { sess->enqueue_op(node); } } @@ -3991,16 +4071,19 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { const char * str_opbatch = getenv("GGML_HEXAGON_OPBATCH"); const char * str_opqueue = getenv("GGML_HEXAGON_OPQUEUE"); const char * str_oppoll = getenv("GGML_HEXAGON_OPPOLL"); - const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE"); + const char * str_opfusion = getenv("GGML_HEXAGON_OPFUSION"); const char * str_opfilter = getenv("GGML_HEXAGON_OPFILTER"); const char * str_profile = getenv("GGML_HEXAGON_PROFILE"); const char * str_etm = getenv("GGML_HEXAGON_ETM"); const char * str_nhvx = getenv("GGML_HEXAGON_NHVX"); const char * str_use_hmx = getenv("GGML_HEXAGON_USE_HMX"); + const char * str_nhmx = getenv("GGML_HEXAGON_NHMX"); + const char * str_mm_select = getenv("GGML_HEXAGON_MM_SELECT"); const char * str_ndev = getenv("GGML_HEXAGON_NDEV"); const char * str_arch = getenv("GGML_HEXAGON_ARCH"); const char * str_vmem = getenv("GGML_HEXAGON_VMEM"); const char * str_mbuf = getenv("GGML_HEXAGON_MBUF"); + const char * str_optrace = getenv("GGML_HEXAGON_OPTRACE"); // Init Arch first since it affects other defaults if (!str_arch) { @@ -4029,12 +4112,14 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) { opt_opstage = str_opstage ? strtoul(str_opstage, NULL, 0) : opt_opstage; opt_opbatch = str_opbatch ? strtoul(str_opbatch, NULL, 0) : opt_opbatch; opt_opqueue = str_opqueue ? strtoul(str_opqueue, NULL, 0) : opt_opqueue; - opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; opt_optrace = str_optrace ? strtoul(str_optrace, NULL, 0) : (opt_opbatch * 128); + opt_oppoll = str_oppoll ? strtoul(str_oppoll, NULL, 0) : opt_oppoll; + opt_opfusion = str_opfusion ? atoi(str_opfusion) : opt_opfusion; opt_profile = str_profile ? atoi(str_profile) : 0; opt_etm = str_etm ? atoi(str_etm) : 0; opt_nhvx = str_nhvx ? strtoul(str_nhvx, NULL, 0) : opt_nhvx; - opt_use_hmx = str_use_hmx ? atoi(str_use_hmx) : opt_use_hmx; + opt_nhmx = str_nhmx ? atoi(str_nhmx) : (str_use_hmx ? atoi(str_use_hmx) : opt_nhmx); + opt_mm_select = str_mm_select ? atoi(str_mm_select) : opt_mm_select; opt_ndev = str_ndev ? strtoul(str_ndev, NULL, 0) : opt_ndev; opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : opt_hostbuf; opt_mbuf = str_mbuf ? strtoul(str_mbuf, NULL, 0) * MiB : opt_mbuf; diff --git a/ggml/src/ggml-hexagon/htp-opnode.h b/ggml/src/ggml-hexagon/htp-opnode.h index 52c727c620..6fe23b0d6a 100644 --- a/ggml/src/ggml-hexagon/htp-opnode.h +++ b/ggml/src/ggml-hexagon/htp-opnode.h @@ -5,10 +5,12 @@ #include "ggml-backend-impl.h" #include "ggml-common.h" +#include #include #include #include #include "htp-ops.h" +#include "htp/matmul-ops.h" struct htp_opnode { ggml_tensor * node = nullptr; @@ -17,6 +19,13 @@ struct htp_opnode { htp_op_code opcode = HTP_OP_INVALID; + std::vector extra_dsts; + + int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS] = {0}; + + htp_opnode(ggml_tensor * node = nullptr, std::vector fused = {}, htp_op_code opcode = HTP_OP_INVALID, std::vector extra_dsts = {}) + : node(node), fused(std::move(fused)), opcode(opcode), extra_dsts(std::move(extra_dsts)) {} + ggml_op op() const { return node->op; } @@ -25,6 +34,26 @@ struct htp_opnode { return fused.empty() ? node : fused.back(); } + void add_fused(ggml_tensor * t, bool extra_dst = false) { + fused.push_back(t); + if (extra_dst) { + extra_dsts.push_back(t); + } + } + + std::vector get_outputs() const { + std::vector res; + if (extra_dsts.empty()) { + res.push_back(dst()); + } else { + res.push_back(node); + for (const auto * x : extra_dsts) { + res.push_back(x); + } + } + return res; + } + const ggml_tensor * src0() const { return node->src[0]; } @@ -37,10 +66,6 @@ struct htp_opnode { return ggml_op_is_empty(node->op); } - void add_fused(ggml_tensor * t) { - fused.push_back(t); - } - bool stackable() const { switch (this->op()) { case GGML_OP_MUL_MAT: @@ -131,87 +156,117 @@ struct htp_opformat { char types[16 * GGML_MAX_SRC]; char buffs[64 * GGML_MAX_SRC]; char names[64 * GGML_MAX_SRC]; + char kparams[128]; - int format_tensor_dims(char * str, const struct ggml_tensor * t) { + int format_tensor_dims(char * str, size_t max_size, const struct ggml_tensor * t) { if (!t) { - return sprintf(str, "NONE"); + return snprintf(str, max_size, "NONE"); } if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%d:%d", (int) t->ne[0], (int) t->ne[1]); + return snprintf(str, max_size, "%d:%d", (int) t->ne[0], (int) t->ne[1]); } else { - return sprintf(str, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); + return snprintf(str, max_size, "%d:%d:%d:%d", (int) t->ne[0], (int) t->ne[1], (int) t->ne[2], (int) t->ne[3]); } } - void format_op_dims(char * str, const htp_opnode & node) { + void format_op_dims(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += format_tensor_dims(p, inputs[0]); + p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[0]), (size_t)(p_end - p)); for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += format_tensor_dims(p, inputs[i]); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)format_tensor_dims(p, p_end - p, inputs[i]), (size_t)(p_end - p)); + } } - p += sprintf(p, " -> "); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } char self[64]; - format_tensor_dims(self, node.dst()); - p += sprintf(p, "%s", self); + format_tensor_dims(self, sizeof(self), node.dst()); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p)); + } } - int format_tensor_strides(char * str, const struct ggml_tensor * t) { + int format_tensor_strides(char * str, size_t max_size, const struct ggml_tensor * t) { if (!t) { - return sprintf(str, "NONE"); + return snprintf(str, max_size, "NONE"); } const char * c = ggml_is_contiguous(t) ? "" : "!"; if (t->ne[2] == 1 && t->ne[3] == 1) { - return sprintf(str, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); + return snprintf(str, max_size, "%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], c); } else { - return sprintf(str, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); + return snprintf(str, max_size, "%zu:%zu:%zu:%zu%s", (size_t) t->nb[0], (size_t) t->nb[1], (size_t) t->nb[2], (size_t) t->nb[3], c); } } - void format_op_strides(char * str, const htp_opnode & node) { + void format_op_strides(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += format_tensor_strides(p, inputs[0]); + p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[0]), (size_t)(p_end - p)); for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += format_tensor_strides(p, inputs[i]); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)format_tensor_strides(p, p_end - p, inputs[i]), (size_t)(p_end - p)); + } } - p += sprintf(p, " -> "); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } char self[64]; - format_tensor_strides(self, node.dst()); - p += sprintf(p, "%s", self); + format_tensor_strides(self, sizeof(self), node.dst()); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", self), (size_t)(p_end - p)); + } } - void format_op_types(char * str, const htp_opnode & node) { + void format_op_types(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"); - - for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? ggml_type_name(inputs[0]->type) : "NONE"), (size_t)(p_end - p)); } - p += sprintf(p, " -> "); + for (size_t i = 1; i < inputs.size(); i++) { + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? ggml_type_name(inputs[i]->type) : "NONE"), (size_t)(p_end - p)); + } + } + + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } - p += sprintf(p, "%s", ggml_type_name(node.dst()->type)); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", ggml_type_name(node.dst()->type)), (size_t)(p_end - p)); + } } const char * tensor_buff_name(const struct ggml_tensor * t) { @@ -221,51 +276,102 @@ struct htp_opformat { return "NONE"; } - void format_op_buffs(char * str, const htp_opnode & node) { + void format_op_buffs(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", tensor_buff_name(inputs[0])); - - for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", tensor_buff_name(inputs[i])); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[0])), (size_t)(p_end - p)); } - p += sprintf(p, " -> "); + for (size_t i = 1; i < inputs.size(); i++) { + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(inputs[i])), (size_t)(p_end - p)); + } + } + + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } - p += sprintf(p, "%s", tensor_buff_name(node.dst())); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", tensor_buff_name(node.dst())), (size_t)(p_end - p)); + } } - void format_op_names(char * str, const htp_opnode & node) { + void format_op_names(char * str, size_t max_size, const htp_opnode & node) { char * p = str; + char * p_end = str + max_size; auto inputs = node.get_inputs(); if (!inputs.empty()) { - p += sprintf(p, "%s", inputs[0] ? inputs[0]->name : "NONE"); - - for (size_t i = 1; i < inputs.size(); i++) { - p += sprintf(p, " x "); - p += sprintf(p, "%s", inputs[i] ? inputs[i]->name : "NONE"); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[0] ? inputs[0]->name : "NONE"), (size_t)(p_end - p)); } - p += sprintf(p, " -> "); + for (size_t i = 1; i < inputs.size(); i++) { + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " x "), (size_t)(p_end - p)); + } + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", inputs[i] ? inputs[i]->name : "NONE"), (size_t)(p_end - p)); + } + } + + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, " -> "), (size_t)(p_end - p)); + } } - p += sprintf(p, "%s", node.dst()->name); + if (p < p_end) { + p += std::min((size_t)snprintf(p, p_end - p, "%s", node.dst()->name), (size_t)(p_end - p)); + } + } + void format_kernel_params(char * str, size_t max_size, const htp_opnode & node) { + if (node.opcode == HTP_OP_MUL_MAT || node.opcode == HTP_OP_MUL_MAT_ID || + node.opcode == HTP_OP_MUL_MAT_QKV || node.opcode == HTP_OP_MUL_MAT_FFN) { + const auto * kparams = (const struct htp_mm_kernel_params *) node.kernel_params; + const char * path = "unknown"; + int32_t type = kparams->kernel_type; + if (type == HTP_MM_KERNEL_HMX_2D || type == HTP_MM_KERNEL_HMX_F16_BATCHED) { + path = "hmx-tiled"; + } else if (type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || type == HTP_MM_KERNEL_HVX_F32_F32_VTCM || + type == HTP_MM_KERNEL_HVX_QUANT_ROW || type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) { + path = "hvx-tiled"; + } else if (type == HTP_MM_KERNEL_HVX_F16_F16_DDR || type == HTP_MM_KERNEL_HVX_F16_F32_DDR || + type == HTP_MM_KERNEL_HVX_F32_F32_DDR || type == HTP_MM_KERNEL_HVX_F32_F16_DDR || + type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + path = "hvx-flat"; + } + snprintf(str, max_size, "%s vtcm %d", path, (int) kparams->vtcm_size); + } else { + snprintf(str, max_size, "----"); + } } void format(const htp_opnode & node) { - format_op_dims(dims, node); - format_op_strides(strides, node); - format_op_types(types, node); - format_op_buffs(buffs, node); - format_op_names(names, node); + format_op_dims(dims, sizeof(dims), node); + format_op_strides(strides, sizeof(strides), node); + format_op_types(types, sizeof(types), node); + format_op_buffs(buffs, sizeof(buffs), node); + format_op_names(names, sizeof(names), node); + format_kernel_params(kparams, sizeof(kparams), node); } - htp_opformat() {} + htp_opformat() { + strides[0] = '\0'; + dims[0] = '\0'; + types[0] = '\0'; + buffs[0] = '\0'; + names[0] = '\0'; + kparams[0] = '\0'; + } htp_opformat(const htp_opnode & node) { format(node); } }; diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 31ba527623..c48a5b86e3 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -19,43 +19,9 @@ add_library(${HTP_LIB} SHARED htp_iface_skel.c worker-pool.c hex-dma.c -) - -target_compile_definitions(${HTP_LIB} PRIVATE - $,HTP_DEBUG=1,NDEBUG=1> - $,FARF_HIGH=1,> - FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) - -if (GGML_HEXAGON_FA_EXP2_HF) - message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") - target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) -endif() - -# HMX acceleration: available on v73+ architectures -set(HTP_HMX_VERSIONS v73 v75 v79 v81) -list(FIND HTP_HMX_VERSIONS ${DSP_VERSION} _hmx_idx) - -if (_hmx_idx GREATER_EQUAL 0) - target_sources(${HTP_LIB} PRIVATE - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - ) - - # -mhmx enables HMX instruction set (needed by files that include hmx-utils.h) - set_source_files_properties( - hmx-flash-attn-ops.c - hmx-matmul-ops.c - hmx-queue.c - PROPERTIES COMPILE_OPTIONS "-mhmx" - ) - - target_compile_definitions(${HTP_LIB} PRIVATE HTP_HAS_HMX=1) -endif() - -build_idl(htp_iface.idl ${HTP_LIB}) - -target_sources(${HTP_LIB} PRIVATE + hmx-queue.c + flash-attn-ops.c + hmx-flash-attn-ops.c matmul-ops.c binary-ops.c unary-ops.c @@ -63,7 +29,6 @@ target_sources(${HTP_LIB} PRIVATE softmax-ops.c act-ops.c rope-ops.c - flash-attn-ops.c set-rows-ops.c get-rows-ops.c cpy-ops.c @@ -79,6 +44,17 @@ target_sources(${HTP_LIB} PRIVATE pad-ops.c ) +target_compile_definitions(${HTP_LIB} PRIVATE + $,HTP_DEBUG=1,NDEBUG=1> + $,FARF_HIGH=1,>) + +if (GGML_HEXAGON_FA_EXP2_HF) + message(STATUS "ggml-htp: HMX_FA_USE_EXP2_HF=1 (use FP16 exp2 polynomial in FA softmax)") + target_compile_definitions(${HTP_LIB} PRIVATE HMX_FA_USE_EXP2_HF=1) +endif() + +build_idl(htp_iface.idl ${HTP_LIB}) + set_target_properties(${HTP_LIB} PROPERTIES EXPORT_COMPILE_COMMANDS ON) install(TARGETS ${HTP_LIB}) diff --git a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake index ed5c198468..3eff2a3986 100644 --- a/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake +++ b/ggml/src/ggml-hexagon/htp/cmake-toolchain.cmake @@ -3,7 +3,7 @@ if (HEXAGON_TOOLCHAIN_INCLUDED) endif() set(HEXAGON_TOOLCHAIN_INCLUDED true) -#Cross Compiling for Hexagon +# Cross Compiling for Hexagon set(HEXAGON TRUE) set(CMAKE_SYSTEM_NAME QURT) set(CMAKE_SYSTEM_PROCESSOR Hexagon) @@ -14,7 +14,6 @@ set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) set(CUSTOM_RUNELF_PATH "") -#To fix backward compatibility with EAI addon. if (NOT HEXAGON_SDK_ROOT) set(HEXAGON_SDK_ROOT $ENV{HEXAGON_SDK_ROOT}) endif() @@ -31,7 +30,6 @@ endif() file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT) file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT) -#Get the Binary extension of the Hexagon Toolchain if(CMAKE_HOST_SYSTEM_NAME STREQUAL Windows) set(HEXAGON_TOOLCHAIN_SUFFIX .exe) endif() @@ -48,12 +46,12 @@ set(CMAKE_TRY_COMPILE_PLATFORM_VARIABLES HEXAGON_TOOLS_ROOT ) -#QURT Related includes and linker flags +# QURT Related includes and linker flags set(V_ARCH ${HEXAGON_ARCH}) set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/ADSP${V_ARCH}MP${V_ARCH_EXTN}") set(_QURT_INSTALL_DIR "${HEXAGON_SDK_ROOT}/rtos/qurt/compute${V_ARCH}${V_ARCH_EXTN}") -if( ${TREE} MATCHES PAKMAN ) +if (${TREE} MATCHES PAKMAN) set(_QURT_INSTALL_DIR "${QURT_IMAGE_DIR}/compute${V_ARCH}${V_ARCH_EXTN}") endif() message(DEBUG "_QURT_INSTALL_DIR:${_QURT_INSTALL_DIR}") @@ -83,11 +81,9 @@ set(QURT_START_LINK_LIBS ) STRING(REPLACE ";" " " QURT_START_LINK_LIBS "${QURT_START_LINK_LIBS}") -set(QURT_END_LINK_LIBS - ${TARGET_DIR}/fini.o - ) +set(QURT_END_LINK_LIBS ${TARGET_DIR}/fini.o) -#Non QURT related includes and linker flags +# Non QURT related includes and linker flags set(TARGET_DIR_NOOS "${HEXAGON_TOOLCHAIN}/Tools/target/hexagon/lib/${HEXAGON_ARCH}") @@ -99,8 +95,10 @@ if (NOT NO_WRAP_MEM_API) set(WRAP_MEMALIGN -Wl,--wrap=memalign) endif() +set(ARCH_FLAGS "-mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -mhmx") + set(PIC_SHARED_LD_FLAGS - -mcpu=${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} + ${ARCH_FLAGS} -G0 -fpic -Wl,-Bsymbolic @@ -120,13 +118,13 @@ STRING(REPLACE ";" " " PIC_SHARED_LD_FLAGS "${PIC_SHARED_LD_FLAGS}") set(HEXAGON_PIC_SHARED_LINK_OPTIONS "${PIC_SHARED_LD_FLAGS}") -#System include paths +# System include paths include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs) include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/incs/stddef) include_directories(SYSTEM ${HEXAGON_SDK_ROOT}/ipc/fastrpc/incs) -#LLVM toolchain setup -#Compiler paths, options and architecture +# LLVM toolchain setup +# Compiler paths, options and architecture set(CMAKE_C_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang${HEXAGON_TOOLCHAIN_SUFFIX}) set(CMAKE_CXX_COMPILER ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-clang++${HEXAGON_TOOLCHAIN_SUFFIX}) set(CMAKE_AR ${HEXAGON_TOOLCHAIN}/Tools/bin/hexagon-ar${HEXAGON_TOOLCHAIN_SUFFIX}) @@ -137,8 +135,8 @@ set(CMAKE_PREFIX_PATH ${HEXAGON_TOOLCHAIN}/Tools/target/hexagon) set(CMAKE_SHARED_LIBRARY_SONAME_C_FLAG "-Wl,-soname,") set(CMAKE_SHARED_LIBRARY_SONAME_CXX_FLAG "-Wl,-soname,") -#Compiler Options -set(COMMON_FLAGS "-mcpu=hexagon${V_ARCH} -m${V_ARCH} -mhvx=${V_ARCH} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") +# Compiler Options +set(COMMON_FLAGS "${ARCH_FLAGS} -fvectorize -flto -Wall -Werror -fno-zero-initialized-in-bss -G0 -fdata-sections -fpic ${XQF_ARGS}") set(CMAKE_CXX_FLAGS_DEBUG "${COMMON_FLAGS} -O0 -D_DEBUG -g") set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${COMMON_FLAGS} -O2 -g") diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index b7511cd644..65f7844ae3 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -18,7 +18,8 @@ #include "htp-ctx.h" #include "htp-ops.h" #include "htp-ops.h" -#include "hmx-ops.h" + +int hmx_flash_attn_ext(struct htp_ops_context * octx); // Must be multiple of 32 #define FLASH_ATTN_BLOCK_SIZE (32 * 2) @@ -633,7 +634,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { return HTP_STATUS_NO_SUPPORT; } -#ifdef HTP_HAS_HMX // HMX path: head_dim multiple of 64, F16 KV, and no sinks if (k->type == HTP_TYPE_F16 && v->type == HTP_TYPE_F16 && k->ne[0] % 64 == 0 && v->ne[0] % 64 == 0 && octx->src[4] == NULL) { int ret = hmx_flash_attn_ext(octx); @@ -642,7 +642,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { } // VTCM too small or other failure -> fall through to HVX path } -#endif struct htp_fa_context factx; factx.octx = octx; diff --git a/ggml/src/ggml-hexagon/htp/hex-common.h b/ggml/src/ggml-hexagon/htp/hex-common.h new file mode 100644 index 0000000000..4714486a04 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hex-common.h @@ -0,0 +1,80 @@ +#ifndef HEX_COMMON_H +#define HEX_COMMON_H + +#include +#include +#include + +#ifndef SIZE_MAX +#define SIZE_MAX ((size_t)-1) +#endif + +#ifndef MAX +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#endif + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +static inline uint32_t hex_ceil_pow2(uint32_t x) { + if (x <= 1) { return 1; } + int p = 2; + x--; + while (x >>= 1) { p <<= 1; } + return p; +} + +static inline size_t hmx_ceil_div(size_t num, size_t den) { + return (num + den - 1) / den; +} + +static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { + return ((size_t) addr & (align - 1)) == 0; +} + +static inline size_t hex_align_up(size_t v, size_t align) { + return hmx_ceil_div(v, align) * align; +} + +static inline size_t hex_align_down(size_t v, size_t align) { + return (v / align) * align; +} + +static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { + uint32_t left_off = (size_t) addr & (chunk_size - 1); + uint32_t right_off = left_off + n; + return right_off <= chunk_size; +} + +static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { + return m * ((n + m - 1) / m); +} + +static inline size_t hex_smin(size_t a, size_t b) { + return a < b ? a : b; +} + +static inline size_t hex_smax(size_t a, size_t b) { + return a > b ? a : b; +} + +static inline void hex_swap_ptr(void ** p1, void ** p2) { + void * t = *p1; + *p1 = *p2; + *p2 = t; +} + +static inline bool hex_mul_overflow(size_t a, size_t b, size_t *out) { + if (a != 0 && b > SIZE_MAX / a) return true; + *out = a * b; + return false; +} + +static inline bool hex_add_overflow(size_t a, size_t b, size_t *out) { + if (a > SIZE_MAX - b) return true; + *out = a + b; + return false; +} + +#endif // HEX_COMMON_H diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index 93c21ebe5e..8031a5679c 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -5,6 +5,7 @@ #include #include #include +#include "hex-utils.h" #include "hex-profile.h" @@ -127,13 +128,8 @@ static inline dma_ptr dma_make_ptr(void *dst, const void *src) return p; } -#if __HVX_ARCH__ < 73 -static const uint32_t dma_src_l2_bypass_on = 1; -static const uint32_t dma_dst_l2_bypass_on = 0; -#else static const uint32_t dma_src_l2_bypass_on = 1; static const uint32_t dma_dst_l2_bypass_on = 1; -#endif static inline bool dma_queue_push_single_1d(dma_queue * q, dma_ptr dptr, size_t size) { if (((q->push_idx + 1) & q->idx_mask) == q->pop_idx) { diff --git a/ggml/src/ggml-hexagon/htp/hex-utils.h b/ggml/src/ggml-hexagon/htp/hex-utils.h index 8e6e3ea750..07930bef6e 100644 --- a/ggml/src/ggml-hexagon/htp/hex-utils.h +++ b/ggml/src/ggml-hexagon/htp/hex-utils.h @@ -11,14 +11,7 @@ #include "hex-fastdiv.h" #include "hex-dump.h" - -#ifndef MAX -#define MAX(a, b) ((a) > (b) ? (a) : (b)) -#endif - -#ifndef MIN -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#endif +#include "hex-common.h" static inline uint64_t hex_get_cycles() { uint64_t cycles = 0; @@ -32,54 +25,6 @@ static inline uint64_t hex_get_pktcnt() { return pktcnt; } -static inline uint32_t hex_ceil_pow2(uint32_t x) { - if (x <= 1) { return 1; } - int p = 2; - x--; - while (x >>= 1) { p <<= 1; } - return p; -} - -static inline size_t hmx_ceil_div(size_t num, size_t den) { - return (num + den - 1) / den; -} - -static inline int32_t hex_is_aligned(const void * addr, uint32_t align) { - return ((size_t) addr & (align - 1)) == 0; -} - -static inline size_t hex_align_up(size_t v, size_t align) { - return hmx_ceil_div(v, align) * align; -} - -static inline size_t hex_align_down(size_t v, size_t align) { - return (v / align) * align; -} - -static inline int32_t hex_is_one_chunk(void * addr, uint32_t n, uint32_t chunk_size) { - uint32_t left_off = (size_t) addr & (chunk_size - 1); - uint32_t right_off = left_off + n; - return right_off <= chunk_size; -} - -static inline uint32_t hex_round_up(uint32_t n, uint32_t m) { - return m * ((n + m - 1) / m); -} - -static inline size_t hex_smin(size_t a, size_t b) { - return a < b ? a : b; -} - -static inline size_t hex_smax(size_t a, size_t b) { - return a > b ? a : b; -} - -static inline void hex_swap_ptr(void ** p1, void ** p2) { - void * t = *p1; - *p1 = *p2; - *p2 = t; -} - static inline void hex_l2fetch(const void * p, uint32_t width, uint32_t stride, uint32_t height) { const uint64_t control = Q6_P_combine_RR(stride, Q6_R_combine_RlRl(width, height)); Q6_l2fetch_AP((void *) p, control); diff --git a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c index 986dde148d..996fd59757 100644 --- a/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/hmx-flash-attn-ops.c @@ -49,7 +49,7 @@ // g_br = hex_align_up(gqa_factor * Br, 32) replaces Br for all Q/O/S/P/D dimensions. // Layout: Q + O_ping + O_pong + K_dma*2 + V_dma*2 + K_tile + V_tile + S + P + D + vectors + scales // Mask is DMA'd into a VTCM buffer (Br rows per KV block) to avoid DDR reads in softmax. -static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool use_pipeline) { +static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, size_t Br, size_t Bc, size_t n_threads, bool pipeline) { const size_t g_br = hex_align_up(gqa_factor * Br, HMX_FP16_TILE_N_ROWS); const size_t q_tile_size = hex_align_up(g_br * DK * sizeof(__fp16), 4096); // Q: [g_br, DK] const size_t o_tile_size = hex_align_up(g_br * DV * sizeof(__fp16), 4096); // O: [g_br, DV] x2 ping-pong @@ -70,7 +70,7 @@ static size_t hmx_fa_compute_vtcm_usage(size_t gqa_factor, size_t DK, size_t DV, + k_dma_size * 2 // K DMA x2 + v_dma_size * 2 // V DMA x2 + k_tile_size * 1 // K tiles - + v_tile_size * (use_pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + + v_tile_size * (pipeline ? 2 : 1) // V tiles (double-buffered if pipelining) + s_tile_size * 2 // S + P + d_tile_size * 1 // D (diagonal matrix) + col_vec_size * 4 // m_vec, l_vec, s_rowmax, p_rowsum @@ -290,7 +290,7 @@ static const int16_t d_tile_scatter_offsets[64] __attribute__((aligned(128))) = struct hmx_fa_context { const struct htp_ops_context * octx; - bool use_pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 + bool pipeline; // true when n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads >= 2 uint32_t n_threads; // Op parameters @@ -409,7 +409,7 @@ static void fa_v_interleave_thread(unsigned int n, unsigned int i, void * data) return; } - __fp16 * v_tiles_dest = factx->use_pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; + __fp16 * v_tiles_dest = factx->pipeline ? factx->vtcm_v_tiles[args->buf_idx] : factx->vtcm_v_tiles[0]; struct htp_thread_trace * tr = factx->octx->ctx ? &factx->octx->ctx->trace[i] : NULL; htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, start); @@ -1312,13 +1312,13 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t g_br = hex_align_up(G * Br, HMX_FP16_TILE_N_ROWS); const uint32_t n_kv_blocks = (nek1 + Bc - 1) / Bc; - const bool use_pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); + const bool pipeline = (n_kv_blocks >= FA_MIN_KV_BLOCKS && n_threads_init >= 2); // Bypass thread pool dispatch for small prompts/non-pipelined prefill by setting n_threads = 1 - const uint32_t n_threads = use_pipeline ? n_threads_init : 1; + const uint32_t n_threads = pipeline ? n_threads_init : 1; FARF(HIGH, "hmx-fa: neq1=%u nek1=%u DK=%u DV=%u G=%u Br=%zu Bc=%zu g_br=%zu n_kv_blocks=%u pipeline=%d vtcm=%zu", - neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, use_pipeline, vtcm_budget); + neq1, nek1, DK, DV, G, Br, Bc, g_br, n_kv_blocks, pipeline, vtcm_budget); // ======== Build context ======== struct hmx_fa_context factx; @@ -1339,7 +1339,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.n_kv_blocks = n_kv_blocks; factx.is_q_fp32 = (q->type == HTP_TYPE_F32); factx.is_dst_fp32 = (dst->type == HTP_TYPE_F32); - factx.use_pipeline = use_pipeline; + factx.pipeline = pipeline; factx.mask_broadcast = (mask != NULL && mask->ne[2] == 1); // Extract op parameters (mutable during softcap adjustment, then stored as const in factx) @@ -1405,7 +1405,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { factx.vtcm_v_fp16[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_dma_bytes); factx.vtcm_k_tiles = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, k_tile_bytes); factx.vtcm_v_tiles[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); - if (use_pipeline) { + if (pipeline) { factx.vtcm_v_tiles[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_cur, v_tile_bytes); } else { factx.vtcm_v_tiles[1] = NULL; @@ -1456,7 +1456,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { // ======== HMX lock strategy ======== // Pipeline: queue thread auto-acquires HMX lock on first push; released by suspend. // Fallback: main thread holds the lock (original behavior). - if (!factx.use_pipeline) { + if (!factx.pipeline) { HAP_compute_res_hmx_lock(ctx->vtcm_rctx); } @@ -1550,7 +1550,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { const size_t k_src_stride = size_k_row_padded / sizeof(__fp16); const size_t v_src_stride = size_v_row_padded / sizeof(__fp16); - if (factx.use_pipeline) { + if (factx.pipeline) { // ================================================================== // Pipeline path: HVX phases โ€– HMX queue worker // ================================================================== @@ -1780,7 +1780,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { fa_build_d_diag_inv_l(&factx, n_row_tiles, n_row_tiles_g_br); // HMX: O_final = diag(1/l) @ O_prev - if (factx.use_pipeline) { + if (factx.pipeline) { on_job.o_curr = o_tile_curr; on_job.o_prev = o_tile_prev; on_job.d_tiles = factx.vtcm_d_tiles; @@ -1826,7 +1826,7 @@ int hmx_flash_attn_ext(struct htp_ops_context * octx) { } // end KV head loop } // end batch loop - if (factx.use_pipeline) { + if (factx.pipeline) { hmx_queue_suspend(ctx->hmx_queue); } else { HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); diff --git a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c b/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c deleted file mode 100644 index 5c37f24ff0..0000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-matmul-ops.c +++ /dev/null @@ -1,2080 +0,0 @@ -#pragma clang diagnostic ignored "-Wgnu-zero-variadic-macro-arguments" -#pragma clang diagnostic ignored "-Wunused-function" -#pragma clang diagnostic ignored "-Wunused-variable" -#pragma clang diagnostic ignored "-Wunused-but-set-variable" - -#include -#include -#include -#include -#include - -#include -#include - -#define GGML_COMMON_DECL_C -#include "ggml-common.h" - -#include "hex-dma.h" -#include "hex-fastdiv.h" -#include "worker-pool.h" - -#include "hvx-utils.h" -#include "hvx-dump.h" -#include "htp-ctx.h" -#include "htp-ops.h" - -#include "hmx-ops.h" -#include "hmx-utils.h" -#include "hmx-queue.h" -#include "hex-profile.h" - -#include "vtcm-utils.h" - -static const __fp16 q4_0_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - -8, 0, -7, 0, -6, 0, -5, 0, -4, 0, -3, 0, -2, 0, -1, 0, 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, -}; - -static const __fp16 q4_1_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - 0, 0, 1, 0, 2, 0, 3, 0, 4, 0, 5, 0, 6, 0, 7, 0, 8, 0, 9, 0, 10, 0, 11, 0, 12, 0, 13, 0, 14, 0, 15, 0, -}; - -// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value -// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 -static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, -}; - -static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { - -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, - 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, -}; - -// Scales per x4x2 logical block: 8 ร— sizeof(__fp16) = 16 bytes -#define HMX_X4X2_SCALES_PER_BLK 8 -#define HMX_X4X2_DBLK_SIZE 16 // 8 * 2 bytes (fp16 scales for Q4_0/Q8_0/IQ4_NL) -#define HMX_X4X2_MXFP4_EBLK_SIZE 8 // 8 * 1 byte (E8M0 scales for MXFP4) - -// Compute the byte stride of one row in x4x2 format. -// Numerically equals ggml_row_size(type, k) when k is 256-aligned, because -// x4x2 packing has the same density as block_q4_0 / block_q8_0. -// Layout per row: [quants: nb*128 (Q4) or nb*256 (Q8)][scales: nb*16 bytes] -// Total per row = nb * (128+16) = 144*nb (Q4) or nb * (256+16) = 272*nb (Q8). -// Callers must ensure k is a multiple of 256 (enforced by proc_hmx_matmul_req). -static inline size_t get_x4x2_row_stride(int weight_type, int k) { - int nb = (k + QK_Q4_0x4x2 - 1) / QK_Q4_0x4x2; - switch (weight_type) { - case HTP_TYPE_Q4_0: - case HTP_TYPE_IQ4_NL: - return (size_t) nb * (QK_Q4_0x4x2 / 2 + HMX_X4X2_DBLK_SIZE); // 144 * nb - case HTP_TYPE_Q4_1: - return (size_t) nb * (QK_Q4_0x4x2 / 2 + 32); // 160 * nb - case HTP_TYPE_Q8_0: - return (size_t) nb * (QK_Q8_0x4x2 + HMX_X4X2_DBLK_SIZE); // 272 * nb - case HTP_TYPE_MXFP4: - return (size_t) nb * (QK_MXFP4x4x2 / 2 + HMX_X4X2_MXFP4_EBLK_SIZE); // 136 * nb - case HTP_TYPE_F16: - return (size_t) k * sizeof(__fp16); - case HTP_TYPE_F32: - return (size_t) k * sizeof(float); - default: - return 0; - } -} - -// --- Overflow-safe arithmetic for VTCM budget calculation --- - -static inline bool hmx_mul_overflow(size_t a, size_t b, size_t *out) { - if (a != 0 && b > SIZE_MAX / a) return true; - *out = a * b; - return false; -} - -static inline bool hmx_add_overflow(size_t a, size_t b, size_t *out) { - if (a > SIZE_MAX - b) return true; - *out = a + b; - return false; -} - -// Search for optimal (mc, nc) chunk sizes within VTCM budget. -// -// VTCM model: nc * per_n_cost + mc * per_m_cost + mc * nc * per_mn_cost + overhead -// -// Minimize ceil(m/mc) * m_block_cost + ceil(n/nc) * n_block_cost. -// All matmul paths repeat weight processing per M-block and activation loading -// per N-block, so discrete block counts drive total overhead. -// Tie-break: when cost is equal, prefer larger mc * nc. -// -// Caller-provided coefficients: -// m_block_cost: penalty per extra M-block (weight redundancy, scales with n). -// n_block_cost: penalty per extra N-block (activation redundancy, scales with m). -// -// Algorithm: nc sweeps from n_max down by 32, analytically solving for mc_max. -// Returns 0 on success, -1 if VTCM is insufficient. -static int hmx_compute_chunks(size_t vtcm_total, - size_t overhead, - size_t per_n_cost, - size_t per_m_cost, - size_t per_mn_cost, - int m, - int n, - size_t m_block_cost, - size_t n_block_cost, - size_t * m_chunk_out, - size_t * n_chunk_out, - size_t * total_out) { - if (m <= 0 || n <= 0) return -1; - if (vtcm_total <= overhead) return -1; - if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; - - const size_t usable = vtcm_total - overhead; - - size_t best_cost = SIZE_MAX; - size_t best_mn = 0; - size_t best_m = 0, best_n = 0; - - const size_t n_max = hex_align_down((size_t)n, HMX_FP16_TILE_N_COLS); - for (size_t nc = n_max; nc >= HMX_FP16_TILE_N_COLS; nc -= HMX_FP16_TILE_N_COLS) { - size_t n_fixed = 0, ncmn = 0, mc_denom = 0; - if (hmx_mul_overflow(nc, per_n_cost, &n_fixed)) continue; - if (n_fixed >= usable) goto next_nc; - - if (hmx_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; - if (hmx_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; - - { - size_t remain = usable - n_fixed; - size_t mc = remain / mc_denom; - mc = hex_align_down(mc, HMX_FP16_TILE_N_ROWS); - mc = hex_smin(mc, (size_t)m); - - if (mc == 0) { - goto next_nc; - } - - size_t mblocks = ((size_t) m + mc - 1) / mc; - size_t nblocks = ((size_t) n + nc - 1) / nc; - size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; - size_t mn = mc * nc; - if (cost < best_cost || (cost == best_cost && mn > best_mn)) { - best_cost = cost; - best_mn = mn; - best_m = mc; - best_n = nc; - } - } - -next_nc: - if (nc == HMX_FP16_TILE_N_COLS) break; // avoid size_t underflow - } - - if (best_m == 0 || best_n == 0) return -1; - - // Compute exact total (with overflow checks) - size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; - if (hmx_mul_overflow(best_n, per_n_cost, &t0)) return -1; - if (hmx_mul_overflow(best_m, per_m_cost, &t1)) return -1; - if (hmx_mul_overflow(best_m, best_n, &mn)) return -1; - if (hmx_mul_overflow(mn, per_mn_cost, &t2)) return -1; - if (hmx_add_overflow(t0, t1, &total)) return -1; - if (hmx_add_overflow(total, t2, &total)) return -1; - if (hmx_add_overflow(total, overhead, &total)) return -1; - - *m_chunk_out = best_m; - *n_chunk_out = best_n; - *total_out = total; - return 0; -} - -// --- x4x2 format dequantizers --- - -// Dequantize one x4x2 Q4_0 group (32 elements from 32 packed bytes) -> 32 FP16 in first 64 bytes. -// In x4x2, sub-blocks 0..3 use lower nibbles, sub-blocks 4..7 use upper nibbles -// of the same 32 packed bytes. -static inline HVX_Vector dequantize_x4x2_q4_0_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_int8)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); -} - -// Batch-dequantize 4 contiguous x4x2 Q4_0 groups (4x32 = 128 packed bytes) using -// full HVX vector width. -// Output: vector_x2 each hold 32 FP16 values in the first 64 bytes. -static inline HVX_Vector_x2 dequantize_x4x2_q4_0_x4groups_hvx( - const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v_int8 = Q6_Vb_vsub_VbVb(v_quants, i8); - - HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_int8); - HVX_Vector v_lo = Q6_V_lo_W(vp_int16); - HVX_Vector v_hi = Q6_V_hi_W(vp_int16); - - v_lo = Q6_Vhf_equals_Vh(v_lo); - v_hi = Q6_Vhf_equals_Vh(v_hi); - - HVX_Vector vscale = hvx_vmemu(scales_4); - HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); - HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - - HVX_Vector_x2 r = { v_lo, v_hi }; - return r; -} - -static inline HVX_Vector dequantize_x4x2_q4_1_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale_offset, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_dm = hvx_vmemu(scale_offset); - HVX_Vector v_scales = hvx_vec_repl_f16(v_dm); - HVX_Vector v_offsets = hvx_vec_repl_f16(Q6_V_vror_VR(v_dm, 2)); - - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(v_quants)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales), v_offsets)); -} - -static inline HVX_Vector_x2 dequantize_x4x2_q4_1_x4groups_hvx( - const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_offsets_4, const HVX_Vector vlut_cvt) { - (void)vlut_cvt; - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_VectorPair vp_int16 = Q6_Wh_vunpack_Vb(v_quants); - HVX_Vector v_lo = Q6_V_lo_W(vp_int16); - HVX_Vector v_hi = Q6_V_hi_W(vp_int16); - - v_lo = Q6_Vhf_equals_Vh(v_lo); - v_hi = Q6_Vhf_equals_Vh(v_hi); - - HVX_Vector vscale_offset = hvx_vmemu(scales_offsets_4); - HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); - HVX_Vector vd = Q6_V_lo_W(dm_deal); - HVX_Vector vm = Q6_V_hi_W(dm_deal); - - HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vd); - HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vd, 4)); - - HVX_Vector v_os01 = hvx_vec_repl_2x_f16(vm); - HVX_Vector v_os23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vm, 4)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01), v_os01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23), v_os23)); - - HVX_Vector_x2 r = { v_lo, v_hi }; - return r; -} - -// LUT-based dequantizers for non-linear IQ4_NL format. -static inline HVX_Vector dequantize_x4x2_iq4_nl_group_hvx(const uint8_t *packed_32, bool upper_nibbles, const __fp16 *scale, const HVX_Vector vlut_cvt) { - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - v_quants = Q6_Vb_vshuff_Vb(v_quants); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); -} - -static inline HVX_Vector_x2 dequantize_x4x2_iq4_nl_x4groups_hvx( - const uint8_t *packed_128, bool upper_nibbles, - const __fp16 *scales_4, const HVX_Vector vlut_cvt) { - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = Q6_Vub_vlsr_VubR(vq, 4 * upper_nibbles); - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - v_quants = Q6_Vb_vshuff_Vb(v_quants); - - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); - HVX_Vector v_hi = Q6_V_hi_W(vp); - - HVX_Vector vscale = hvx_vmemu(scales_4); - HVX_Vector v_sc01 = hvx_vec_repl_2x_f16(vscale); - HVX_Vector v_sc23 = hvx_vec_repl_2x_f16(Q6_V_vror_VR(vscale, 4)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - - HVX_Vector_x2 r = { v_lo, v_hi }; - return r; -} - -// Dequantize one x4x2 Q8_0 group (32 int8 quants) -> 32 FP16 in first 64 bytes. -static inline HVX_Vector dequantize_x4x2_q8_0_group_hvx(const int8_t *quants_32, const __fp16 *scale) { - HVX_Vector vq = hvx_vmemu(quants_32); - HVX_Vector v_scales = hvx_vec_repl_f16(hvx_vmemu(scale)); - HVX_Vector v0 = Q6_V_lo_W(Q6_Wh_vunpack_Vb(vq)); - HVX_Vector v_hf = Q6_Vhf_equals_Vh(v0); - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_scales)); -} - -// --- MXFP4 E8M0 scale conversion and dequantization --- -// -// HVX batch-convert 8 E8M0 bytes (one x4x2 block's scales) to __fp16[8] on stack. -// Scalar loads from the stack array execute on the scalar pipeline, in parallel -// with HVX vlut16/vmpy/vscatter โ€” freeing HVX slots in the hot loop. -// Arithmetic: fp16_bits = clamp(e - 112, 0, 30) << 10 -// e=0..112 -> 0 (underflow), e=113..142 -> valid fp16, e>=143 -> clamped to 2^15. - -typedef struct { - __fp16 v[8] __attribute__((aligned(16))); -} mxfp4_scales_t; - -static inline mxfp4_scales_t mxfp4_convert_scales(const uint8_t * e8m0_8) { - mxfp4_scales_t s; - HVX_Vector v = hvx_vmemu(e8m0_8); - HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); - vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); - vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); - vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); - vh = Q6_Vh_vasl_VhR(vh, 10); - hvx_vec_store_u(s.v, 16, vh); - return s; -} - -static inline HVX_Vector mxfp4_extract_splat(mxfp4_scales_t scales, int idx) { - return hvx_vec_splat_f16(scales.v[idx]); -} - -// Dequantize one x4x2 MXFP4 group (32 elements from 32 packed bytes) -> 32 FP16. -static inline HVX_Vector dequantize_x4x2_mxfp4_group_hvx(const uint8_t * packed_32, - bool upper_nibbles, - int sub_blk, - const HVX_Vector vlut_cvt, - mxfp4_scales_t scales) { - HVX_Vector vq = hvx_vmemu(packed_32); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - HVX_Vector v_sc = mxfp4_extract_splat(scales, sub_blk); - - v_quants = Q6_Vb_vshuff_Vb(v_quants); - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_hf = Q6_V_lo_W(vp); - - return Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hf, v_sc)); -} - -// Batch-dequantize 4 contiguous x4x2 MXFP4 groups (4x32 = 128 packed bytes). -static inline HVX_Vector_x4 dequantize_x4x2_mxfp4_x4groups_hvx(const uint8_t * packed_128, - bool upper_nibbles, - int sub_blk_base, - const HVX_Vector vlut_cvt, - mxfp4_scales_t scales) { - HVX_Vector vq = hvx_vmemu(packed_128); - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - HVX_Vector v_quants = upper_nibbles ? Q6_Vub_vlsr_VubR(vq, 4) : vq; - v_quants = Q6_V_vand_VV(v_quants, mask_h4); - - v_quants = Q6_Vb_vshuff_Vb(v_quants); - - HVX_VectorPair vp = Q6_Wh_vlut16_VbVhR(v_quants, vlut_cvt, 0); - HVX_Vector v_lo = Q6_V_lo_W(vp); - HVX_Vector v_hi = Q6_V_hi_W(vp); - - HVX_VectorPred q64 = Q6_Q_vsetq_R(64); - HVX_Vector v_sc01 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 0), - mxfp4_extract_splat(scales, sub_blk_base + 1)); - HVX_Vector v_sc23 = Q6_V_vmux_QVV(q64, mxfp4_extract_splat(scales, sub_blk_base + 2), - mxfp4_extract_splat(scales, sub_blk_base + 3)); - - v_lo = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_lo, v_sc01)); - v_hi = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(v_hi, v_sc23)); - - HVX_Vector_x4 r = { v_lo, Q6_V_vror_VR(v_lo, 64), v_hi, Q6_V_vror_VR(v_hi, 64) }; - return r; -} - -typedef struct { - __fp16 *dst; - const uint8_t *src; - int n_cols; - int k_block; - size_t row_stride; - int weight_type; - int n_tot_tiles; - int n_tiles_per_task; - int n_tasks; - int n_k_tiles; - struct fastdiv_values n_k_tiles_div; - struct htp_thread_trace * traces; -} x4x2_dequantize_state_t; - -// Dequantize a tile range from x4x2 weight data (already in VTCM) to tile-major FP16. -// Input: vtcm_src has n_cols rows of x4x2 data, each row_stride bytes. -// Output: vtcm_dst in tile-major FP16 layout. - -#define DEFINE_DEQUANTIZE_Q4_TASK(suffix, lut_name, helper_prefix, dblk_size, scale_step) \ -static void dequantize_x4x2_weight_to_fp16_tiles_task_##suffix( \ - const x4x2_dequantize_state_t *state, \ - int start_tile, int end_tile) { \ - \ - const int n_k_tiles = state->n_k_tiles; \ - const int qrow_size = (unsigned)state->k_block / 2; \ - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; \ - const HVX_Vector vlut_cvt = hvx_vmem(lut_name); \ - \ - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); \ - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); \ - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); \ - \ - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); \ - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); \ - \ - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { \ - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } \ - \ - if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { \ - unsigned blk_idx = ((kt * 32) / QK_Q4_0x4x2); \ - unsigned sub_blk_base = ((kt * 32) % QK_Q4_0x4x2) / 32; \ - bool upper = (sub_blk_base >= 4); \ - unsigned packed_off = blk_idx * (QK_Q4_0x4x2 / 2); \ - unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk_base * (scale_step); \ - \ - __fp16 *tile_bases[4]; \ - for (unsigned g = 0; g < 4; g++) { \ - tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; \ - } \ - \ - HVX_Vector v_off = v_scat_base; \ - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ - \ - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { \ - const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ - const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ - \ - HVX_Vector_x2 dv0 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ - r0 + packed_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[0]); \ - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[1]); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - \ - HVX_Vector_x2 dv1 = dequantize_x4x2_##helper_prefix##_x4groups_hvx( \ - r1 + packed_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt); \ - Q6_vscatter_RMVwV((size_t)tile_bases[0], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[0]); \ - Q6_vscatter_RMVwV((size_t)tile_bases[2], 2 * HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[1]); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - } \ - \ - for (int g = 0; g < 4; g++) { (void) *(volatile HVX_Vector *)(tile_bases[g]); } \ - t += 4; kt += 4; \ - continue; \ - } \ - \ - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; \ - { \ - unsigned blk_idx = (kt * 32) / QK_Q4_0x4x2; \ - unsigned sub_blk = ((kt * 32) % QK_Q4_0x4x2) / 32; \ - bool upper = (sub_blk >= 4); \ - unsigned byte_off = blk_idx * (QK_Q4_0x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; \ - unsigned scale_off = qrow_size + blk_idx * (dblk_size) + sub_blk * (scale_step); \ - \ - HVX_Vector v_off = v_scat_base; \ - unsigned row_offset = ct * HMX_FP16_TILE_N_COLS * state->row_stride; \ - unsigned row1 = ct * HMX_FP16_TILE_N_COLS + 1; \ - \ - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2, row1 += 2) { \ - const uint8_t *r0 = state->src + row_offset; row_offset += state->row_stride; \ - const uint8_t *r1 = state->src + row_offset; row_offset += state->row_stride; \ - \ - HVX_Vector v0 = dequantize_x4x2_##helper_prefix##_group_hvx( \ - r0 + byte_off, upper, (const __fp16 *)(r0 + scale_off), vlut_cvt); \ - HVX_Vector v1 = (row1 < (unsigned)state->n_cols) \ - ? dequantize_x4x2_##helper_prefix##_group_hvx( \ - r1 + byte_off, upper, (const __fp16 *)(r1 + scale_off), vlut_cvt) \ - : Q6_V_vzero(); \ - \ - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); \ - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); \ - } \ - (void) *(volatile HVX_Vector *)(tile_base); \ - } \ - ++t; ++kt; \ - } \ - \ - if (start_tile < end_tile) { \ - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); \ - } \ -} \ - \ -static void dequantize_x4x2_worker_loop_##suffix(unsigned int n, unsigned int i, void *data) { \ - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; \ - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \ - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ - int start = task_id * state->n_tiles_per_task; \ - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ - dequantize_x4x2_weight_to_fp16_tiles_task_##suffix(state, start, end); \ - } \ - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ -} - -DEFINE_DEQUANTIZE_Q4_TASK(q4_0, q4_0_to_fp16_lut, q4_0, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) -DEFINE_DEQUANTIZE_Q4_TASK(q4_1, q4_1_to_fp16_lut, q4_1, 32, 4) -DEFINE_DEQUANTIZE_Q4_TASK(iq4_nl, iq4_nl_to_fp16_lut, iq4_nl, HMX_X4X2_DBLK_SIZE, (int)sizeof(__fp16)) - -static void dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const int qrow_size = (unsigned)state->k_block / 2; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - // Batch-4 fast path for MXFP4 - if ((kt % 4 == 0) && (t + 4 <= (unsigned)end_tile) && (fastdiv(t + 3, &n_k_tiles_div) == ct)) { - int blk_idx = (kt * 32) / QK_MXFP4x4x2; - int sub_blk_base = ((kt * 32) % QK_MXFP4x4x2) / 32; - bool upper = (sub_blk_base >= 4); - int packed_off = blk_idx * (QK_MXFP4x4x2 / 2); - int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; - - __fp16 * tile_bases[4]; - for (int g = 0; g < 4; g++) { - tile_bases[g] = state->dst + (t + g) * HMX_FP16_TILE_N_ELMS; - } - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - const uint8_t * r0 = state->src + row0 * state->row_stride; - const uint8_t * r1 = state->src + row1 * state->row_stride; - - mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - - HVX_Vector_x4 dv0, dv1; - dv0 = dequantize_x4x2_mxfp4_x4groups_hvx(r0 + packed_off, upper, sub_blk_base, vlut_cvt, r0_e8); - if (row1 < state->n_cols) { - mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - dv1 = dequantize_x4x2_mxfp4_x4groups_hvx(r1 + packed_off, upper, sub_blk_base, vlut_cvt, r1_e8); - } else { - dv1.v[0] = dv1.v[1] = dv1.v[2] = dv1.v[3] = Q6_V_vzero(); - } - - for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv0.v[g]); - } - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - for (int g = 0; g < 4; g++) { - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_bases[g], HMX_FP16_TILE_SIZE - 1, v_off, dv1.v[g]); - } - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - - for (int g = 0; g < 4; g++) { - (void) *(volatile HVX_Vector *) (tile_bases[g]); - } - - t += 4; kt += 4; - continue; - } - - // Single-tile fallback - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int blk_idx = (kt * 32) / QK_MXFP4x4x2; - int sub_blk = ((kt * 32) % QK_MXFP4x4x2) / 32; - bool upper = (sub_blk >= 4); - int byte_off = blk_idx * (QK_MXFP4x4x2 / 2) + (upper ? (sub_blk - 4) : sub_blk) * 32; - int e8m0_blk_off = qrow_size + blk_idx * HMX_X4X2_MXFP4_EBLK_SIZE; - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t * r0 = state->src + row0 * state->row_stride; - const uint8_t * r1 = state->src + row1 * state->row_stride; - - mxfp4_scales_t r0_e8 = mxfp4_convert_scales(r0 + e8m0_blk_off); - - HVX_Vector v0 = dequantize_x4x2_mxfp4_group_hvx(r0 + byte_off, upper, sub_blk, vlut_cvt, r0_e8); - HVX_Vector v1; - if (row1 < state->n_cols) { - mxfp4_scales_t r1_e8 = mxfp4_convert_scales(r1 + e8m0_blk_off); - v1 = dequantize_x4x2_mxfp4_group_hvx(r1 + byte_off, upper, sub_blk, vlut_cvt, r1_e8); - } else { - v1 = Q6_V_vzero(); - } - - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t) tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *) (tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void dequantize_x4x2_worker_loop_mxfp4(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - dequantize_x4x2_weight_to_fp16_tiles_task_mxfp4(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - -static void dequantize_x4x2_weight_to_fp16_tiles_task_q8_0( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const int qrow_size = state->k_block; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int blk_idx = (kt * 32) / QK_Q8_0x4x2; - int sub_blk = ((kt * 32) % QK_Q8_0x4x2) / 32; - int byte_off = blk_idx * QK_Q8_0x4x2 + sub_blk * 32; - int scale_off = qrow_size + blk_idx * HMX_X4X2_DBLK_SIZE + sub_blk * (int)sizeof(__fp16); - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = state->src + row0 * state->row_stride; - const uint8_t *r1 = state->src + row1 * state->row_stride; - - HVX_Vector v0 = dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r0 + byte_off), (const __fp16 *)(r0 + scale_off)); - HVX_Vector v1 = (row1 < state->n_cols) ? dequantize_x4x2_q8_0_group_hvx((const int8_t *)(r1 + byte_off), (const __fp16 *)(r1 + scale_off)) : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void dequantize_x4x2_worker_loop_q8_0(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - dequantize_x4x2_weight_to_fp16_tiles_task_q8_0(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - -static void convert_f16_weight_to_fp16_tiles_task( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int byte_off = kt * 32 * sizeof(__fp16); - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = state->src + row0 * state->row_stride; - const uint8_t *r1 = state->src + row1 * state->row_stride; - - HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); - HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v0); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v1); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - convert_f16_weight_to_fp16_tiles_task(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - -static void quantize_f32_weight_to_fp16_tiles_task( - const x4x2_dequantize_state_t *state, - int start_tile, int end_tile) { - - const int n_k_tiles = state->n_k_tiles; - const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; - - const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); - const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); - const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); - - unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); - unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); - - for (unsigned t = start_tile; t < (unsigned)end_tile; ) { - if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } - - __fp16 *tile_base = state->dst + t * HMX_FP16_TILE_N_ELMS; - { - int byte_off = kt * 32 * sizeof(float); - - HVX_Vector v_off = v_scat_base; - for (int r = 0; r < HMX_FP16_TILE_N_ROWS; r += 2) { - int row0 = ct * HMX_FP16_TILE_N_COLS + r; - int row1 = row0 + 1; - - const uint8_t *r0 = state->src + row0 * state->row_stride; - const uint8_t *r1 = state->src + row1 * state->row_stride; - - HVX_Vector v0_f32 = hvx_vmemu((const float *)(r0 + byte_off)); - HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmemu((const float *)(r1 + byte_off)) : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); - - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - - HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); - Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HMX_FP16_TILE_SIZE - 1, v_off, v_out_hi); - v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); - } - (void) *(volatile HVX_Vector *)(tile_base); - } - ++t; ++kt; - } - - if (start_tile < end_tile) { - (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HMX_FP16_TILE_N_ELMS); - } -} - -static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { - x4x2_dequantize_state_t *state = (x4x2_dequantize_state_t *)data; - struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); - for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { - int start = task_id * state->n_tiles_per_task; - int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); - quantize_f32_weight_to_fp16_tiles_task(state, start, end); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); -} - - -static void dequantize_x4x2_weight_chunk_to_fp16_tiles( - struct htp_context *ctx, __fp16 *vtcm_dst, - const void *vtcm_src, int n_cols, int k_block, - size_t row_stride, int weight_type, - int n_k_tiles, struct fastdiv_values n_k_tiles_div, - worker_callback_t dequant_worker_fn, int n_threads) { - - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - assert(k_block % HMX_FP16_TILE_N_COLS == 0); - - size_t n_col_tiles = n_cols / HMX_FP16_TILE_N_COLS; - size_t n_tot_tiles = n_col_tiles * n_k_tiles; - - size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); - - x4x2_dequantize_state_t state; - state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; - state.n_tot_tiles = n_tot_tiles; - state.n_tiles_per_task = n_tiles_per_task; - state.dst = vtcm_dst; - state.src = (const uint8_t *)vtcm_src; - state.n_cols = n_cols; - state.k_block = k_block; - state.row_stride = row_stride; - state.weight_type = weight_type; - state.n_k_tiles = n_k_tiles; - state.n_k_tiles_div = n_k_tiles_div; - state.traces = ctx ? ctx->trace : NULL; - - if (state.n_tasks == 1 || n_threads == 1) { - dequant_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_threads); - } -} - -// --- End x4x2 dequantizers --- - -#pragma clang diagnostic ignored "-Wbackend-plugin" // spurios warning for hmx intrinsics - -// requires external HMX lock -static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, - int n_row_tiles, int n_col_tiles, int n_dot_tiles) { - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); - - Q6_bias_mxmem2_A((void *)scales); - for (int r = 0; r < n_row_tiles; ++r) { - for (size_t c = 0; c < n_col_tiles; ++c) { - Q6_mxclracc_hf(); - - const __fp16 *row_tiles = activation + r * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - const __fp16 *col_tiles = weight + c * n_dot_tiles * HMX_FP16_TILE_N_ELMS; - - for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { - k_block = hex_smin(n_dot_tiles - k, 32); - const uint32_t range = 2048u * (uint32_t)k_block - 1; - Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); - row_tiles += k_block * HMX_FP16_TILE_N_ELMS; - col_tiles += k_block * HMX_FP16_TILE_N_ELMS; - } - - __fp16 *out_tile = output + (r * n_col_tiles + c) * HMX_FP16_TILE_N_ELMS; - Q6_mxmem_AR_after_hf(out_tile, 0); - } - } -} - -// --- Async HMX matmul job (for pipeline overlap) --- - -typedef struct { - __fp16 * output; - const __fp16 * activation; - const __fp16 * weight; - const __fp16 * scales; - uint32_t n_row_tiles; - uint32_t n_col_tiles; - uint32_t n_dot_tiles; -} hmx_matmul_job_t; - -static void hmx_matmul_worker_fn(void * data) { - hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; - FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); - core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); -} - -static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, - __fp16 * output, - const __fp16 * activation, - const __fp16 * weight, - const __fp16 * scales, - int n_row_tiles, - int n_col_tiles, - int n_dot_tiles) { - job->output = output; - job->activation = activation; - job->weight = weight; - job->scales = scales; - job->n_row_tiles = n_row_tiles; - job->n_col_tiles = n_col_tiles; - job->n_dot_tiles = n_dot_tiles; -} - -// output : fp16 -> f32p - -static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, int n_rows, int n_cols, int n) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; - - const HVX_Vector one = hvx_vec_splat_f16(1.0); - - for (size_t r = 0; r < n_rows; r += 2) { - const size_t r0 = r / HMX_FP16_TILE_N_ROWS; - const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile - const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; - float *output_row_base = dst + r * n; // global memory row base for row r (and r+1) - - #pragma unroll(4) - for (size_t c = 0; c < n_cols; c += HMX_FP16_TILE_N_COLS) { - const size_t c0 = c / HMX_FP16_TILE_N_COLS; - const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; - HVX_Vector v = ((const HVX_Vector *) tile)[r1]; - HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); - - volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row_base + c + 0); - volatile HVX_Vector *pv_out1 = (volatile HVX_Vector *) (output_row_base + c + n); // next row in global memory - - *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); - if (r + 1 < n_rows) { - *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); - } - } - } -} - -typedef struct { - const __fp16 *vtcm_src; - float *dst; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int n_cols; - int n; // DDR row stride (total output columns) - struct htp_thread_trace * traces; -} output_transfer_task_state_t; - -static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { - output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i); - - for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { - int chunk_idx = task_id * st->n_chunks_per_task; - size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - - float *dst = st->dst + chunk_idx * st->n; - const __fp16 *vtcm_src = st->vtcm_src + chunk_idx * st->n_cols; - transfer_output_chunk_fp16_to_fp32(dst, vtcm_src, chunk_size, st->n_cols, st->n); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i); -} - -static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, - int n_rows, int n_cols, int n, int n_threads) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - - size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : HMX_FP16_TILE_N_ROWS; // must be multiple of HMX_FP16_TILE_N_ROWS (32) - - output_transfer_task_state_t state; - state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; - state.n_tot_chunks = n_tot_chunks; - state.n_chunks_per_task = n_chunks_per_task; - state.dst = dst; - state.vtcm_src = vtcm_src; - state.n_cols = n_cols; - state.n = n; - state.traces = ctx ? ctx->trace : NULL; - - if (state.n_tasks == 1 || n_threads == 1) { - transfer_output_chunk_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, n_threads); - } -} - -// activations : fp32 -> fp16 - -static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, int n_rows, int k_block, int k_stride) { - const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; - - int r = 0; - - #pragma unroll(2) - for (r = 0; r < n_rows_tiled; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - const HVX_Vector *pv_in0 = (const HVX_Vector *) (src + (r + 0) * k_stride); - const HVX_Vector *pv_in1 = (const HVX_Vector *) (src + (r + 1) * k_stride); - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = *pv_in1++; - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - // compute output position - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } - - for (; r < n_rows_padded; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - const bool row0_valid = r < n_rows; - const bool row1_valid = (r + 1) < n_rows; - - const HVX_Vector *pv_in0 = row0_valid ? (const HVX_Vector *) (src + (r + 0) * k_stride) : NULL; - const HVX_Vector *pv_in1 = row1_valid ? (const HVX_Vector *) (src + (r + 1) * k_stride) : NULL; - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); - HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - // compute output position - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } -} - -typedef struct { - __fp16 *dst; - const float *src; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int k_block; - int k_stride; - struct htp_thread_trace * traces; -} activation_transfer_task_state_t; - -static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { - activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i); - - for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { - // one chunk: one row - int chunk_idx = task_id * st->n_chunks_per_task; - size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); - - __fp16 *dst = st->dst + chunk_idx * st->k_block; - const float *src = st->src + chunk_idx * st->k_stride; - transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i); -} - -static void transfer_activation_chunk_threaded(struct htp_context *ctx, __fp16 *dst, const float *src, int n_rows, int k_block, int k_stride, int n_threads) { - assert(k_block % HMX_FP16_TILE_N_COLS == 0 && k_stride % HMX_FP16_TILE_N_COLS == 0); - assert(VLEN == 32 * sizeof(float)); - - size_t n_tot_chunks = n_rows; - size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address - - activation_transfer_task_state_t state; - state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; - state.n_tot_chunks = n_tot_chunks; - state.n_chunks_per_task = n_chunks_per_task; - state.dst = dst; - state.src = src; - state.k_block = k_block; - state.k_stride = k_stride; - state.traces = ctx ? ctx->trace : NULL; - - if (state.n_tasks == 1 || n_threads == 1) { - transfer_activation_chunk_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_threads); - } -} - -// C += AB -static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, - const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, - int n_row_tiles, int n_col_tiles, int n_dot_tiles, bool zero_init) { - __builtin_assume(n_row_tiles > 0); - __builtin_assume(n_col_tiles > 0); - __builtin_assume(n_dot_tiles > 0); - - Q6_bias_mxmem2_A((void *)col_scales); - - const size_t dot_tile_stride = n_dot_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t i = 0; i < n_row_tiles; ++i) { - const __fp16 *row_base = a + i * dot_tile_stride; - __fp16 *res_base = c + i * n_col_tiles * HMX_FP16_TILE_N_ELMS; - for (size_t j = 0; j < n_col_tiles; ++j) { - Q6_mxclracc_hf(); - - const __fp16 *col_tiles = b + j * dot_tile_stride; - const __fp16 *row_tiles = row_base; - __fp16 *accum_tile = res_base + j * HMX_FP16_TILE_N_ELMS; - if (!zero_init) { - Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); - Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); - } - - for (int k = 0, k_block; k < n_dot_tiles; k += k_block) { - k_block = hex_smin(n_dot_tiles - k, 32); - const uint32_t range = 2048u * (uint32_t)k_block - 1; - Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); - Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); - row_tiles += k_block * HMX_FP16_TILE_N_ELMS; - col_tiles += k_block * HMX_FP16_TILE_N_ELMS; - } - - Q6_mxmem_AR_after_hf(accum_tile, 0); - } - } -} - -int hmx_matmul_2d_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const uint8_t *restrict permuted_weight, int m, int k, int n, - int act_stride, int weight_stride, int weight_type) { - if (k % 32 != 0 || n % 32 != 0) { return -1; } - - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - worker_callback_t dequant_worker_fn = NULL; - switch (weight_type) { - case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; - case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; - case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; - case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; - case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; - case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; - case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; - default: - return -1; - } - - const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; - const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - - // --- Dynamic Mode Configuration --- - const bool use_pipeline = (m > 32); - const int num_threads = (m <= 32) ? 1 : ctx->n_threads; - - // --- Dynamic VTCM layout --- - const size_t vec_dot_size = k * sizeof(__fp16); - const size_t vtcm_budget = ctx->vtcm_size; - size_t vtcm_used = 0; - - // Pipeline = 4-stage DMAโ†’dequantโ†’HMXโ†’store with HMX worker overlap. - const size_t size_per_n = row_stride + (use_pipeline ? 2 * vec_dot_size : vec_dot_size); // Q + S0 + S1 (dequant bufs) - const size_t size_per_mn = (use_pipeline ? 2 : 1) * sizeof(__fp16); // O x 2 (output double buffer) - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, - hex_align_up(m, HMX_FP16_TILE_N_ROWS), n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-2d: VTCM too small : m %d k %d n %d budget %zu", m, k, n, vtcm_budget); - return -1; - } - - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); - const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - - size_t scratch0_size, scratch1_size, scratch2_size; - scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); // dequant buf 0 - scratch1_size = use_pipeline ? scratch0_size : 0; // dequant buf 1 - scratch2_size = use_pipeline ? output_area_size : 0; // output buf 1 - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; - void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - - vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; - if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); - return -1; - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(HIGH, "hmx-mm-2d: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", - m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); - - - - int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); - - if (use_pipeline) { - // --- Asynchronous Pipelined Loop (Current implementation) --- - hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors - - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - - void *vtcm_qweight = vtcm_weight; - void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; - void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; - - // prologue: A0 - const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); - { - const uint8_t *qweight_chunk_A0 = permuted_weight; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A0), row_stride, weight_stride, row_stride, n_cols_A0); - } - - { - const float *activation_chunk = activation + mr * act_stride; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - } - - // prologue: B0, A1, submit C0 (async), B1 (overlaps C0) - { - // B0: wait for DMA, dequant weight chunk 0 - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[0], vtcm_qweight, n_cols_A0, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - - // A1: issue DMA for weight chunk 1 - const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); - if (1 < n_chunk_cnt) { - const uint8_t *qweight_chunk_A1 = permuted_weight + n_chunk_n_cols * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_A1), row_stride, weight_stride, row_stride, n_cols_A1); - } - - // submit C0 (non-blocking โ€” HMX worker executes in parallel) - hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_activation, - (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, - hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_A0, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); - - // B1: DMA pop + dequant (runs in parallel with C0 on HMX worker) - if (1 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[1], vtcm_qweight, n_cols_A1, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - } - } - - // main loop: wait C_i โ†’ submit C_{i+1} โ†’ D_i + B_{i+2} (parallel with C_{i+1}) - for (int i = 0; i < n_chunk_cnt; ++i) { - const size_t nc = i * n_chunk_n_cols; - const size_t nc_p1 = nc + 1 * n_chunk_n_cols; - const size_t nc_p2 = nc + 2 * n_chunk_n_cols; - - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); - const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); - - // issue A_{i+2}: DMA push (non-blocking) - if (i + 2 < n_chunk_cnt) { - const uint8_t *qweight_chunk_p2 = permuted_weight + nc_p2 * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_qweight, qweight_chunk_p2), row_stride, weight_stride, row_stride, n_cols_p2); - } - - // wait C_i: block until prologue/previous C completes - hmx_queue_pop(ctx->hmx_queue); - - // submit C_{i+1} (non-blocking, overlaps with D_i + B_{i+2} below) - if (i + 1 < n_chunk_cnt) { - hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], - (__fp16 *) vtcm_activation, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], - vtcm_scales, hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS), - hmx_ceil_div(n_cols_p1, HMX_FP16_TILE_N_COLS), k / HMX_FP16_TILE_N_ROWS); - hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); - } - - // D_i: store output (multi-thread HVX, parallel with C_{i+1}) - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, n, num_threads); - - // B_{i+2}: DMA pop + dequant (multi-thread HVX, parallel with C_{i+1}) - if (i + 2 < n_chunk_cnt) { - dma_queue_pop(ctx->dma[0]); - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_weight_bufs[(i + 2) % 2], vtcm_qweight, n_cols_p2, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - } - } - } - hmx_queue_suspend(ctx->hmx_queue); - } else { - // --- Synchronous Loop (Optimized for small/non-pipelined cases) --- - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - - // Load Activation - const float *activation_chunk = activation + mr * act_stride; - transfer_activation_chunk_threaded(ctx, vtcm_activation, activation_chunk, n_rows, k, act_stride, num_threads); - - for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - // A: DMA Load Weight - const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); - dma_queue_pop(ctx->dma[0]); - - // B: Dequantize / Convert Weight - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - - // C: HMX Compute (Synchronous) - { - struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - } - - // D: Output Store - float *output_chunk = dst + (mr * n + nc); - transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, n, num_threads); - } - } - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - } - - - - return 0; -} - -// - -static inline int hmx_matmul_batch_r2(const hmx_matmul_f16_f32_batched_params_t *params) { - return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; -} - -static inline int hmx_matmul_batch_r3(const hmx_matmul_f16_f32_batched_params_t *params) { - return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; -} - -static inline const __fp16 *hmx_matmul_weight_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, - int dst_b2, int dst_b3) { - const int r2 = hmx_matmul_batch_r2(params); - const int r3 = hmx_matmul_batch_r3(params); - return (const __fp16 *) ((const uint8_t *) params->permuted_weight + - (size_t) (dst_b2 / r2) * params->src0_nb2 + - (size_t) (dst_b3 / r3) * params->src0_nb3); -} - -static inline const float *hmx_matmul_activation_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (const float *) ((const uint8_t *) params->activation + - (size_t) dst_b2 * params->src1_nb2 + - (size_t) dst_b3 * params->src1_nb3); -} - -static inline float *hmx_matmul_dst_batch_ptr(const hmx_matmul_f16_f32_batched_params_t *params, - int dst_b2, int dst_b3) { - return (float *) ((uint8_t *) params->dst + - (size_t) dst_b2 * params->dst_nb2 + - (size_t) dst_b3 * params->dst_nb3); -} - -static int hmx_matmul_f16_f32_batched_legacy(struct htp_context *ctx, - const hmx_matmul_f16_f32_batched_params_t *params) { - int ret = 0; - for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { - for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { - ret = hmx_matmul_f16_f32(ctx, hmx_matmul_dst_batch_ptr(params, b2, b3), - hmx_matmul_activation_batch_ptr(params, b2, b3), - hmx_matmul_weight_batch_ptr(params, b2, b3), - params->m, params->k, params->n, - params->act_stride, params->weight_stride); - } - } - return ret; -} - -int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params) { - if (!ctx || !params || !params->dst || !params->activation || !params->permuted_weight) { return -1; } - if (!params->m || !params->k || !params->n) { return -1; } - if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } - if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } - if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } - if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } - - if (!hex_is_aligned(params->dst, VLEN) || - !hex_is_aligned(params->activation, VLEN) || - !hex_is_aligned(params->permuted_weight, VLEN)) { - return -1; - } - - const int group_size = hmx_matmul_batch_r2(params); - - if (group_size <= 1) { - FARF(HIGH, "%s: no dim2 GQA reuse (group=%d), using legacy batched loop", __func__, group_size); - return hmx_matmul_f16_f32_batched_legacy(ctx, params); - } - - // Grouped path: reuse interleaved weight across all q_heads sharing a - // kv_head. Each q_head gets its own activation buffer in VTCM (so - // activation is loaded once per m_chunk and reused across all n_chunks), - // and each q_head is computed individually to avoid tile-major packing - // issues. m_chunk_n_rows is always a multiple of 32 (from - // hmx_compute_chunks), so per-head tile arrays don't overlap. - const size_t vtcm_budget = ctx->vtcm_size; - const size_t vec_dot_size = params->k * sizeof(__fp16); - - // When the activation has a large stride (e.g. permuted Q tensor with - // act_stride >> k), HVX vector loads from strided DDR thrash L2 cache. - // Allocate an F32 scratch buffer in VTCM and use 2D DMA to gather - // strided rows into a contiguous block before the F32->F16 conversion. - const bool use_dma_activation = (params->act_stride > params->k); - const size_t f32_scratch_per_m = use_dma_activation ? (size_t) params->k * sizeof(float) : 0; - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0, vtcm_used = 0; - // FP16 weight: interleave and activation load have similar per-element cost. - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, - /*per_n=*/3 * vec_dot_size, - /*per_m=*/group_size * vec_dot_size + f32_scratch_per_m, - /*per_mn=*/sizeof(__fp16), - hex_align_up(params->m, HMX_FP16_TILE_N_ROWS), params->n, - /*m_block_cost=*/(size_t) params->n, - /*n_block_cost=*/(size_t) params->m, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used) != 0) { - FARF(HIGH, "%s: grouped path does not fit VTCM, falling back to legacy batched loop", __func__); - return hmx_matmul_f16_f32_batched_legacy(ctx, params); - } - - const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t f32_scratch_size = use_dma_activation - ? hex_align_up(m_chunk_n_rows * (size_t) params->k * sizeof(float), HMX_FP16_TILE_SIZE) : 0; - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; - - if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { - FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to legacy batched loop", __func__); - return hmx_matmul_f16_f32_batched_legacy(ctx, params); - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 - - FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", - __func__, params->m, params->k, params->n, group_size, params->ne13, - m_chunk_n_rows, n_chunk_n_cols, - (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); - - - - const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); - const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (int b3 = 0; b3 < params->ne13; ++b3) { - for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { - const __fp16 *weight_group = hmx_matmul_weight_batch_ptr(params, b2_base, b3); - - for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HMX_FP16_TILE_N_ROWS); - - // Pre-load activations for all heads in the group (once per m_chunk). - // When the source is strided (permuted Q), use 2D DMA to gather - // contiguous rows into a VTCM scratch buffer first, then HVX - // converts from the contiguous VTCM buffer. This avoids L2 cache - // thrashing from HVX loads at large strides. - for (int g = 0; g < group_size; ++g) { - const float *activation_chunk = hmx_matmul_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; - __fp16 *vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - if (use_dma_activation) { - const size_t row_bytes = (size_t) params->k * sizeof(float); - const size_t stride_bytes = (size_t) params->act_stride * sizeof(float); - dma_queue_push(ctx->dma[0], - dma_make_ptr(vtcm_f32_act, activation_chunk), - row_bytes, stride_bytes, row_bytes, n_rows); - dma_queue_pop(ctx->dma[0]); - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - vtcm_f32_act, (int) n_rows, - params->k, params->k, ctx->n_threads); - } else { - transfer_activation_chunk_threaded(ctx, vtcm_act_g, - activation_chunk, (int) n_rows, - params->k, params->act_stride, ctx->n_threads); - } - } - - void *buf_curr = vtcm_scratch0; - void *buf_next = vtcm_scratch1; - - { - const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); - } - - for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HMX_FP16_TILE_N_COLS); - - { - dma_queue_pop(ctx->dma[0]); - - const size_t nc_next = nc + n_chunk_n_cols; - if (nc_next < (size_t) params->n) { - const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); - const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; - - dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), - fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); - } - - hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, - 0, n_cols); - hex_swap_ptr(&buf_curr, &buf_next); - } - - // Reuse the interleaved weight for every q_head in this GQA group - for (int g = 0; g < group_size; ++g) { - { - const __fp16 * vtcm_act_g = vtcm_activation + (size_t) g * act_head_stride; - struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, - params->k / 32); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - } - - { - float *output = hmx_matmul_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; - transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, ctx->n_threads); - } - } - } - } - } - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - - - - return 0; -} - -int hmx_matmul_f16_f32(struct htp_context *ctx, float *restrict dst, const float *restrict activation, - const __fp16 *restrict permuted_weight, int m, int k, int n, - int act_stride, int weight_stride) { - if (!dst || !activation || !permuted_weight || !m || !n || !k) { return -1; } - return hmx_matmul_2d_f32(ctx, dst, activation, (const uint8_t *)permuted_weight, m, k, n, - act_stride, weight_stride * (int)sizeof(__fp16), HTP_TYPE_F16); -} - -struct mmid_row_mapping { - uint32_t i1; - uint32_t i2; -}; - -typedef struct { - __fp16 *dst; - const float *src; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int k_block; - const struct mmid_row_mapping *matrix_rows; - int cur_a; - int mapping_stride; - int ne11; - struct fastdiv_values ne11_div; - size_t nb11; - size_t nb12; - int start_row; - int cne1; - struct htp_thread_trace *traces; -} activation_transfer_gathered_task_state_t; - -typedef struct { - const __fp16 *vtcm_src; - float *dst; - int n_tasks; - int n_tot_chunks; - int n_chunks_per_task; - int n_cols; - const struct mmid_row_mapping *matrix_rows; - int cur_a; - int mapping_stride; - size_t dst_nb1; - size_t dst_nb2; - int start_row; - int cne1; - struct htp_thread_trace *traces; -} output_transfer_scattered_task_state_t; - -static void transfer_activation_chunk_fp32_to_fp16_gathered( - __fp16 *restrict vtcm_dst, - const float *restrict src, - int start_row, - int n_rows, - int k_block, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - int ne11, - const struct fastdiv_values * ne11_div, - size_t nb11, - size_t nb12, - int cne1) { - const int n_rows_padded = hex_align_up(n_rows, HMX_FP16_TILE_N_ROWS); - const int n_rows_tiled = (n_rows / HMX_FP16_TILE_N_ROWS) * HMX_FP16_TILE_N_ROWS; - - int r = 0; - - #pragma unroll(2) - for (r = 0; r < n_rows_tiled; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - int r_idx0 = start_row + r + 0; - int r_idx1 = start_row + r + 1; - - struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; - struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; - - int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); - int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); - - const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); - const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); - - const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; - const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; - - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = *pv_in0++; - HVX_Vector v1 = *pv_in1++; - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } - - for (; r < n_rows_padded; r += 2) { - int r0 = r / HMX_FP16_TILE_N_ROWS; // tile row index - int r1 = r % HMX_FP16_TILE_N_ROWS; // intra-tile row idx - - const bool row0_valid = (start_row + r + 0) < cne1; - const bool row1_valid = (start_row + r + 1) < cne1; - - const float *row0_ptr = NULL; - const float *row1_ptr = NULL; - - if (row0_valid) { - struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; - int i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); - row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); - } - if (row1_valid) { - struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; - int i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); - row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); - } - - const HVX_Vector *pv_in0 = (const HVX_Vector *) row0_ptr; - const HVX_Vector *pv_in1 = (const HVX_Vector *) row1_ptr; - - for (int c = 0; c < k_block; c += 32) { - HVX_Vector v0 = row0_valid ? *pv_in0++ : Q6_V_vzero(); - HVX_Vector v1 = row1_valid ? *pv_in1++ : Q6_V_vzero(); - - HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); - - int c0 = c / HMX_FP16_TILE_N_COLS; // tile column index - int tile_idx = r0 * (k_block / HMX_FP16_TILE_N_COLS) + c0; - - HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HMX_FP16_TILE_N_ELMS); - tile[r1 / 2] = v_out; - } - } -} - -static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { - activation_transfer_gathered_task_state_t *st = data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, i); - - int chunk_idx = i; - int chunk_size = st->n_chunks_per_task; - int start_row = st->start_row + chunk_idx * chunk_size; - int n_rows = hex_smin(st->cne1 - start_row, chunk_size); - if (n_rows > 0) { - __fp16 *dst = st->dst + (size_t)(start_row - st->start_row) * st->k_block; - transfer_activation_chunk_fp32_to_fp16_gathered( - dst, st->src, start_row, n_rows, st->k_block, - st->matrix_rows, st->cur_a, st->mapping_stride, - st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, i); -} - -static void transfer_activation_chunk_gathered_threaded( - struct htp_context *ctx, - __fp16 *dst, - const float *src, - int start_row, - int n_rows, - int k_block, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - int ne11, - size_t nb11, - size_t nb12, - int cne1, - int n_threads) { - if (n_rows <= 0) return; - int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); - chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); - - int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); - - activation_transfer_gathered_task_state_t state = { - .dst = dst, - .src = src, - .n_tasks = actual_threads, - .n_tot_chunks = n_rows, - .n_chunks_per_task = chunks_per_thread, - .k_block = k_block, - .matrix_rows = matrix_rows, - .cur_a = cur_a, - .mapping_stride = mapping_stride, - .ne11 = ne11, - .ne11_div = init_fastdiv_values(ne11), - .nb11 = nb11, - .nb12 = nb12, - .start_row = start_row, - .cne1 = cne1, - .traces = ctx ? ctx->trace : NULL, - }; - - if (actual_threads <= 1) { - transfer_activation_chunk_gathered_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_gathered_worker_fn, &state, actual_threads); - } -} - -static void transfer_output_chunk_fp16_to_fp32_scattered( - float *restrict dst, - const __fp16 *restrict vtcm_src, - int start_row, - int n_rows, - int n_cols, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - size_t dst_nb1, - size_t dst_nb2, - int cne1) { - assert(n_cols % HMX_FP16_TILE_N_COLS == 0); - const size_t tile_row_stride = (n_cols / HMX_FP16_TILE_N_COLS) * HMX_FP16_TILE_N_ELMS; - - const HVX_Vector one = hvx_vec_splat_f16(1.0); - - for (size_t r = 0; r < n_rows; r += 2) { - const size_t r0 = r / HMX_FP16_TILE_N_ROWS; - const size_t r1 = (r % HMX_FP16_TILE_N_ROWS) / 2; // index of the row pair within the tile - const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; - - int r_idx0 = start_row + (int)r + 0; - int r_idx1 = start_row + (int)r + 1; - - if (r_idx0 >= cne1) break; - - struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; - float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); - - float *output_row1 = NULL; - if (r_idx1 < cne1) { - struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; - output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); - } - - #pragma unroll(4) - for (size_t c = 0; c < (size_t)n_cols; c += HMX_FP16_TILE_N_COLS) { - const size_t c0 = c / HMX_FP16_TILE_N_COLS; - const __fp16 *tile = row_base + c0 * HMX_FP16_TILE_N_ELMS; - HVX_Vector v = ((const HVX_Vector *) tile)[r1]; - HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); - - volatile HVX_Vector *pv_out0 = (volatile HVX_Vector *) (output_row0 + c); - volatile HVX_Vector *pv_out1 = output_row1 ? (volatile HVX_Vector *) (output_row1 + c) : NULL; - - *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); - if (pv_out1) { - *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); - } - } - } -} - -static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { - output_transfer_scattered_task_state_t *st = data; - struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, i); - - int chunk_idx = i; - int chunk_size = st->n_chunks_per_task; - int start_row = st->start_row + chunk_idx * chunk_size; - int n_rows = hex_smin(st->cne1 - start_row, chunk_size); - if (n_rows > 0) { - const __fp16 *src = st->vtcm_src + (size_t)(start_row - st->start_row) * st->n_cols; - transfer_output_chunk_fp16_to_fp32_scattered( - st->dst, src, start_row, n_rows, st->n_cols, - st->matrix_rows, st->cur_a, st->mapping_stride, - st->dst_nb1, st->dst_nb2, st->cne1); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, i); -} - -static void transfer_output_chunk_scattered_threaded( - struct htp_context *ctx, - float *dst, - const __fp16 *vtcm_src, - int start_row, - int n_rows, - int n_cols, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride, - size_t dst_nb1, - size_t dst_nb2, - int cne1, - int n_threads) { - if (n_rows <= 0) return; - int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); - chunks_per_thread = hex_align_up(chunks_per_thread, HMX_FP16_TILE_N_ROWS); - - int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); - - output_transfer_scattered_task_state_t state = { - .vtcm_src = vtcm_src, - .dst = dst, - .n_tasks = actual_threads, - .n_tot_chunks = n_rows, - .n_chunks_per_task = chunks_per_thread, - .n_cols = n_cols, - .matrix_rows = matrix_rows, - .cur_a = cur_a, - .mapping_stride = mapping_stride, - .dst_nb1 = dst_nb1, - .dst_nb2 = dst_nb2, - .start_row = start_row, - .cne1 = cne1, - .traces = ctx ? ctx->trace : NULL, - }; - - if (actual_threads <= 1) { - transfer_output_chunk_scattered_worker_fn(1, 0, &state); - } else { - worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); - } -} - -int hmx_matmul_id_2d_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const uint8_t *permuted_weight, - int m, int k, int n, - int ne11, - size_t act_nb1, size_t act_nb2, - size_t dst_nb1, size_t dst_nb2, - int weight_stride, - int weight_type, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride) { - const int cne1 = m; - const int m_padded = hex_align_up(m, 32); - - if (k % 32 != 0 || n % 32 != 0) { return -1; } - - if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN) || !hex_is_aligned(permuted_weight, VLEN)) { - return -1; - } - - size_t row_stride = get_x4x2_row_stride(weight_type, k); - if (row_stride == 0) { - return -1; - } - - worker_callback_t dequant_worker_fn = NULL; - switch (weight_type) { - case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_0; break; - case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_x4x2_worker_loop_iq4_nl; break; - case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_x4x2_worker_loop_q4_1; break; - case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_x4x2_worker_loop_mxfp4; break; - case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_x4x2_worker_loop_q8_0; break; - case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; - case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; - default: - return -1; - } - - const int n_k_tiles = k / HMX_FP16_TILE_N_COLS; - const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); - - const int num_threads = ctx->n_threads; - - const size_t vec_dot_size = k * sizeof(__fp16); - const size_t vtcm_budget = ctx->vtcm_size; - size_t vtcm_used = 0; - - const size_t size_per_n = row_stride + vec_dot_size; - const size_t size_per_mn = sizeof(__fp16); - - size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; - if (hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, /*per_m=*/vec_dot_size, size_per_mn, - m_padded, n, - /*m_block_cost=*/(size_t) n * 3, - /*n_block_cost=*/(size_t) m_padded * 2, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { - FARF(HIGH, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); - return -1; - } - - const size_t weight_area_size = hex_align_up(n_chunk_n_cols * row_stride, HMX_FP16_TILE_SIZE); - const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HMX_FP16_TILE_SIZE); - const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HMX_FP16_TILE_SIZE); - - size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HMX_FP16_TILE_SIZE); - - uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; - __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); - __fp16 *vtcm_activation = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); - __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); - void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); - __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); - - vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; - if (vtcm_used > vtcm_budget) { - FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); - return -1; - } - - hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); - - HAP_compute_res_hmx_lock(ctx->vtcm_rctx); - - for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { - const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); - const size_t n_row_tiles = hmx_ceil_div(n_rows, HMX_FP16_TILE_N_ROWS); - - transfer_activation_chunk_gathered_threaded( - ctx, vtcm_activation, activation, (int) mr, (int) n_rows, k, - matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, num_threads); - - for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { - const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); - const size_t n_col_tiles = hmx_ceil_div(n_cols, HMX_FP16_TILE_N_COLS); - - const uint8_t *qweight_chunk = permuted_weight + nc * weight_stride; - dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, qweight_chunk), row_stride, weight_stride, row_stride, n_cols); - dma_queue_pop(ctx->dma[0]); - - dequantize_x4x2_weight_chunk_to_fp16_tiles(ctx, vtcm_scratch0, vtcm_weight, n_cols, k, row_stride, weight_type, n_k_tiles, n_k_tiles_div, dequant_worker_fn, num_threads); - - { - struct htp_thread_trace * tr = ctx ? &ctx->trace[HTP_MAX_NTHREADS] : NULL; - htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - core_dot_chunk_fp16(vtcm_output, vtcm_activation, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HMX_FP16_TILE_N_ROWS); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, HTP_MAX_NTHREADS); - } - - transfer_output_chunk_scattered_threaded( - ctx, dst, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, - matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, num_threads); - } - } - - HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); - return 0; -} diff --git a/ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h b/ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h new file mode 100644 index 0000000000..b7fba22a87 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hmx-mm-kernels-tiled.h @@ -0,0 +1,1306 @@ +#include "hmx-utils.h" +#include "hmx-queue.h" + +// MXFP4 dequantization LUT: maps 4-bit index to fp16 mantissa value +// kvalues: 0, 0.5, 1, 1.5, 2, 3, 4, 6, 0, -0.5, -1, -1.5, -2, -3, -4, -6 +static const __fp16 mxfp4_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + 0, 0, 0.5, 0, 1, 0, 1.5, 0, 2, 0, 3, 0, 4, 0, 6, 0, 0, 0, -0.5, 0, -1, 0, -1.5, 0, -2, 0, -3, 0, -4, 0, -6, 0, +}; + +static const __fp16 iq4_nl_to_fp16_lut[64] __attribute__((aligned(VLEN))) = { + -127, 0, -104, 0, -83, 0, -65, 0, -49, 0, -35, 0, -22, 0, -10, 0, + 1, 0, 13, 0, 25, 0, 38, 0, 53, 0, 69, 0, 89, 0, 113, 0, +}; + +// --- tiled format dequantizers --- + +typedef struct { + struct htp_context * ctx; + struct htp_thread_trace * traces; + __fp16 * dst; + const uint8_t * src; + + struct fastdiv_values n_k_tiles_div; + uint32_t n_k_tiles; + uint32_t n_tot_tiles; + uint32_t n_tiles_per_task; + uint32_t tile_size; + uint32_t aligned_tile_size; + uint32_t n_tasks; + uint32_t n_cols; + uint32_t k_block; + size_t row_stride; + uint32_t weight_type; +} tiled_dequantize_state_t; + +// Dequantize a single tile from tiled weight data (already in VTCM) to tile-major FP16. +static void dequantize_tiled_weight_to_fp16_task_q4_0( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v_sc = hvx_vmem(tile_src + 512); + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(v_sc, v_sc, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Offsetting (-8) + v_lo0 = Q6_Vb_vsub_VbVb(v_lo0, i8); + v_hi0 = Q6_Vb_vsub_VbVb(v_hi0, i8); + v_lo1 = Q6_Vb_vsub_VbVb(v_lo1, i8); + v_hi1 = Q6_Vb_vsub_VbVb(v_hi1, i8); + v_lo2 = Q6_Vb_vsub_VbVb(v_lo2, i8); + v_hi2 = Q6_Vb_vsub_VbVb(v_hi2, i8); + v_lo3 = Q6_Vb_vsub_VbVb(v_lo3, i8); + v_hi3 = Q6_Vb_vsub_VbVb(v_hi3, i8); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Unpack to 16-bit + HVX_VectorPair vp_int16_lo0 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_VectorPair vp_int16_hi0 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_VectorPair vp_int16_lo1 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_VectorPair vp_int16_hi1 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_VectorPair vp_int16_lo2 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_VectorPair vp_int16_hi2 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_VectorPair vp_int16_lo3 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_VectorPair vp_int16_hi3 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf3)); + + // Convert and scale multiplication + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo0)), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo0)), v_scale_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi0)), v_scale_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi0)), v_scale_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo1)), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo1)), v_scale_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi1)), v_scale_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi1)), v_scale_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo2)), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo2)), v_scale_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi2)), v_scale_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi2)), v_scale_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo3)), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo3)), v_scale_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi3)), v_scale_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi3)), v_scale_duplicated)); + + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_q4_1( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector vscale_offset = hvx_vmem(tile_src + 512); + HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(vscale_offset, vscale_offset, -2); + HVX_Vector vd = Q6_V_lo_W(dm_deal); + HVX_Vector vm = Q6_V_hi_W(dm_deal); + + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(vd, vd, -2)); + HVX_Vector v_offset_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(vm, vm, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Unpack to 16-bit + HVX_VectorPair vp_int16_lo0 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_VectorPair vp_int16_hi0 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_VectorPair vp_int16_lo1 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_VectorPair vp_int16_hi1 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_VectorPair vp_int16_lo2 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_VectorPair vp_int16_hi2 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_VectorPair vp_int16_lo3 = Q6_Wh_vunpack_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_VectorPair vp_int16_hi3 = Q6_Wh_vunpack_Vb(Q6_V_hi_W(vp_shuf3)); + + // Convert, multiply, add offset + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo0)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo0)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi0)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi0)), v_scale_duplicated), v_offset_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo1)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo1)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi1)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi1)), v_scale_duplicated), v_offset_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo2)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo2)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi2)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi2)), v_scale_duplicated), v_offset_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_lo3)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_lo3)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_hi3)), v_scale_duplicated), v_offset_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vadd_Vqf16Vhf(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_hi3)), v_scale_duplicated), v_offset_duplicated)); + + // Parallel Stores + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_iq4_nl( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector vlut_cvt = hvx_vmem(iq4_nl_to_fp16_lut); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v_sc = hvx_vmem(tile_src + 512); + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(v_sc, v_sc, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Shuffle for LUT lookup + HVX_Vector v_q_lo0 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_Vector v_q_hi0 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_Vector v_q_lo1 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_Vector v_q_hi1 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_Vector v_q_lo2 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_Vector v_q_hi2 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_Vector v_q_lo3 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_Vector v_q_hi3 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf3)); + + // LUT lookup + HVX_VectorPair vp_lo0 = Q6_Wh_vlut16_VbVhR(v_q_lo0, vlut_cvt, 0); + HVX_VectorPair vp_hi0 = Q6_Wh_vlut16_VbVhR(v_q_hi0, vlut_cvt, 0); + HVX_VectorPair vp_lo1 = Q6_Wh_vlut16_VbVhR(v_q_lo1, vlut_cvt, 0); + HVX_VectorPair vp_hi1 = Q6_Wh_vlut16_VbVhR(v_q_hi1, vlut_cvt, 0); + HVX_VectorPair vp_lo2 = Q6_Wh_vlut16_VbVhR(v_q_lo2, vlut_cvt, 0); + HVX_VectorPair vp_hi2 = Q6_Wh_vlut16_VbVhR(v_q_hi2, vlut_cvt, 0); + HVX_VectorPair vp_lo3 = Q6_Wh_vlut16_VbVhR(v_q_lo3, vlut_cvt, 0); + HVX_VectorPair vp_hi3 = Q6_Wh_vlut16_VbVhR(v_q_hi3, vlut_cvt, 0); + + // Convert and scale multiplication + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi0), v_scale_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi0), v_scale_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi1), v_scale_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi1), v_scale_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi2), v_scale_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi2), v_scale_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi3), v_scale_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi3), v_scale_duplicated)); + + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_mxfp4( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector vlut_cvt = hvx_vmem(mxfp4_to_fp16_lut); + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v = hvx_vmem(tile_src + 512); + HVX_Vector vh = Q6_V_lo_W(Q6_Wuh_vunpack_Vub(v)); + vh = Q6_Vh_vsub_VhVh(vh, Q6_Vh_vsplat_R(112)); + vh = Q6_Vh_vmax_VhVh(vh, Q6_V_vzero()); + vh = Q6_Vh_vmin_VhVh(vh, Q6_Vh_vsplat_R(30)); + vh = Q6_Vh_vasl_VhR(vh, 10); + + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(vh, vh, -2)); + + // Load all 4 groups in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + // Nibble extraction + HVX_Vector v_lo0 = Q6_V_vand_VV(vq0, mask_h4); + HVX_Vector v_hi0 = Q6_Vub_vlsr_VubR(vq0, 4); + HVX_Vector v_lo1 = Q6_V_vand_VV(vq1, mask_h4); + HVX_Vector v_hi1 = Q6_Vub_vlsr_VubR(vq1, 4); + HVX_Vector v_lo2 = Q6_V_vand_VV(vq2, mask_h4); + HVX_Vector v_hi2 = Q6_Vub_vlsr_VubR(vq2, 4); + HVX_Vector v_lo3 = Q6_V_vand_VV(vq3, mask_h4); + HVX_Vector v_hi3 = Q6_Vub_vlsr_VubR(vq3, 4); + + // Shuffling + HVX_VectorPair vp_shuf0 = Q6_W_vshuff_VVR(v_hi0, v_lo0, -1); + HVX_VectorPair vp_shuf1 = Q6_W_vshuff_VVR(v_hi1, v_lo1, -1); + HVX_VectorPair vp_shuf2 = Q6_W_vshuff_VVR(v_hi2, v_lo2, -1); + HVX_VectorPair vp_shuf3 = Q6_W_vshuff_VVR(v_hi3, v_lo3, -1); + + // Shuffle for LUT lookup + HVX_Vector v_q_lo0 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf0)); + HVX_Vector v_q_hi0 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf0)); + HVX_Vector v_q_lo1 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf1)); + HVX_Vector v_q_hi1 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf1)); + HVX_Vector v_q_lo2 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf2)); + HVX_Vector v_q_hi2 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf2)); + HVX_Vector v_q_lo3 = Q6_Vb_vshuff_Vb(Q6_V_lo_W(vp_shuf3)); + HVX_Vector v_q_hi3 = Q6_Vb_vshuff_Vb(Q6_V_hi_W(vp_shuf3)); + + // LUT lookup + HVX_VectorPair vp_lo0 = Q6_Wh_vlut16_VbVhR(v_q_lo0, vlut_cvt, 0); + HVX_VectorPair vp_hi0 = Q6_Wh_vlut16_VbVhR(v_q_hi0, vlut_cvt, 0); + HVX_VectorPair vp_lo1 = Q6_Wh_vlut16_VbVhR(v_q_lo1, vlut_cvt, 0); + HVX_VectorPair vp_hi1 = Q6_Wh_vlut16_VbVhR(v_q_hi1, vlut_cvt, 0); + HVX_VectorPair vp_lo2 = Q6_Wh_vlut16_VbVhR(v_q_lo2, vlut_cvt, 0); + HVX_VectorPair vp_hi2 = Q6_Wh_vlut16_VbVhR(v_q_hi2, vlut_cvt, 0); + HVX_VectorPair vp_lo3 = Q6_Wh_vlut16_VbVhR(v_q_lo3, vlut_cvt, 0); + HVX_VectorPair vp_hi3 = Q6_Wh_vlut16_VbVhR(v_q_hi3, vlut_cvt, 0); + + // Convert and scale multiplication + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo0), v_scale_duplicated)); + HVX_Vector v_grp0_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi0), v_scale_duplicated)); + HVX_Vector v_grp0_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi0), v_scale_duplicated)); + + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo1), v_scale_duplicated)); + HVX_Vector v_grp1_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi1), v_scale_duplicated)); + HVX_Vector v_grp1_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi1), v_scale_duplicated)); + + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo2), v_scale_duplicated)); + HVX_Vector v_grp2_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi2), v_scale_duplicated)); + HVX_Vector v_grp2_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi2), v_scale_duplicated)); + + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_lo3), v_scale_duplicated)); + HVX_Vector v_grp3_2 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_lo_W(vp_hi3), v_scale_duplicated)); + HVX_Vector v_grp3_3 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_V_hi_W(vp_hi3), v_scale_duplicated)); + + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp0_2; + hvx_vmem(dst_ptr + 3 * 64) = v_grp0_3; + + hvx_vmem(dst_ptr + 4 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp1_2; + hvx_vmem(dst_ptr + 7 * 64) = v_grp1_3; + + hvx_vmem(dst_ptr + 8 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp2_2; + hvx_vmem(dst_ptr + 11 * 64) = v_grp2_3; + + hvx_vmem(dst_ptr + 12 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp3_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp3_2; + hvx_vmem(dst_ptr + 15 * 64) = v_grp3_3; + } +} + +static void dequantize_tiled_weight_to_fp16_task_q8_0( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + for (uint32_t t = start_tile; t < end_tile; t++) { + const uint8_t * tile_src = state->src + t * state->aligned_tile_size; + __fp16 * dst_ptr = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + + HVX_Vector v_sc = hvx_vmem(tile_src + 1024); + HVX_Vector v_scale_duplicated = Q6_V_lo_W(Q6_W_vshuff_VVR(v_sc, v_sc, -2)); + + // Load groups 0-3 in parallel + HVX_Vector vq0 = hvx_vmem(tile_src + 0 * 128); + HVX_Vector vq1 = hvx_vmem(tile_src + 1 * 128); + HVX_Vector vq2 = hvx_vmem(tile_src + 2 * 128); + HVX_Vector vq3 = hvx_vmem(tile_src + 3 * 128); + + HVX_VectorPair vp_int16_0 = Q6_Wh_vunpack_Vb(vq0); + HVX_VectorPair vp_int16_1 = Q6_Wh_vunpack_Vb(vq1); + HVX_VectorPair vp_int16_2 = Q6_Wh_vunpack_Vb(vq2); + HVX_VectorPair vp_int16_3 = Q6_Wh_vunpack_Vb(vq3); + + // Load groups 4-7 in parallel + HVX_Vector vq4 = hvx_vmem(tile_src + 4 * 128); + HVX_Vector vq5 = hvx_vmem(tile_src + 5 * 128); + HVX_Vector vq6 = hvx_vmem(tile_src + 6 * 128); + HVX_Vector vq7 = hvx_vmem(tile_src + 7 * 128); + + HVX_VectorPair vp_int16_4 = Q6_Wh_vunpack_Vb(vq4); + HVX_VectorPair vp_int16_5 = Q6_Wh_vunpack_Vb(vq5); + HVX_VectorPair vp_int16_6 = Q6_Wh_vunpack_Vb(vq6); + HVX_VectorPair vp_int16_7 = Q6_Wh_vunpack_Vb(vq7); + + // Convert and scale multiply for groups 0-3 + HVX_Vector v_grp0_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_0)), v_scale_duplicated)); + HVX_Vector v_grp0_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_0)), v_scale_duplicated)); + HVX_Vector v_grp1_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_1)), v_scale_duplicated)); + HVX_Vector v_grp1_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_1)), v_scale_duplicated)); + HVX_Vector v_grp2_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_2)), v_scale_duplicated)); + HVX_Vector v_grp2_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_2)), v_scale_duplicated)); + HVX_Vector v_grp3_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_3)), v_scale_duplicated)); + HVX_Vector v_grp3_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_3)), v_scale_duplicated)); + + // Store groups 0-3 + hvx_vmem(dst_ptr + 0 * 64) = v_grp0_0; + hvx_vmem(dst_ptr + 1 * 64) = v_grp0_1; + hvx_vmem(dst_ptr + 2 * 64) = v_grp1_0; + hvx_vmem(dst_ptr + 3 * 64) = v_grp1_1; + hvx_vmem(dst_ptr + 4 * 64) = v_grp2_0; + hvx_vmem(dst_ptr + 5 * 64) = v_grp2_1; + hvx_vmem(dst_ptr + 6 * 64) = v_grp3_0; + hvx_vmem(dst_ptr + 7 * 64) = v_grp3_1; + + // Convert and scale multiply for groups 4-7 + HVX_Vector v_grp4_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_4)), v_scale_duplicated)); + HVX_Vector v_grp4_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_4)), v_scale_duplicated)); + HVX_Vector v_grp5_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_5)), v_scale_duplicated)); + HVX_Vector v_grp5_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_5)), v_scale_duplicated)); + HVX_Vector v_grp6_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_6)), v_scale_duplicated)); + HVX_Vector v_grp6_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_6)), v_scale_duplicated)); + HVX_Vector v_grp7_0 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_lo_W(vp_int16_7)), v_scale_duplicated)); + HVX_Vector v_grp7_1 = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(Q6_Vhf_equals_Vh(Q6_V_hi_W(vp_int16_7)), v_scale_duplicated)); + + // Store groups 4-7 + hvx_vmem(dst_ptr + 8 * 64) = v_grp4_0; + hvx_vmem(dst_ptr + 9 * 64) = v_grp4_1; + hvx_vmem(dst_ptr + 10 * 64) = v_grp5_0; + hvx_vmem(dst_ptr + 11 * 64) = v_grp5_1; + hvx_vmem(dst_ptr + 12 * 64) = v_grp6_0; + hvx_vmem(dst_ptr + 13 * 64) = v_grp6_1; + hvx_vmem(dst_ptr + 14 * 64) = v_grp7_0; + hvx_vmem(dst_ptr + 15 * 64) = v_grp7_1; + } +} + +static void convert_f16_weight_to_fp16_tiles_task( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const uint32_t n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + { + uint32_t byte_off = kt * 32 * sizeof(__fp16); + + HVX_Vector v_off = v_scat_base; + for (uint32_t r = 0; r < HTP_MM_HMX_TILE_N_ROWS; r += 2) { + uint32_t row0 = ct * HTP_MM_HMX_TILE_N_COLS + r; + uint32_t row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0 = hvx_vmemu((const __fp16 *)(r0 + byte_off)); + HVX_Vector v1 = (row1 < state->n_cols) ? hvx_vmemu((const __fp16 *)(r1 + byte_off)) : Q6_V_vzero(); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v0); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v1); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HTP_MM_HMX_TILE_N_ELMS); + } +} + +static void quantize_f32_weight_to_fp16_tiles_task( + const tiled_dequantize_state_t *state, + uint32_t start_tile, uint32_t end_tile) { + + const uint32_t n_k_tiles = state->n_k_tiles; + const struct fastdiv_values n_k_tiles_div = state->n_k_tiles_div; + + const HVX_Vector v_scat_base = hvx_vmem(hmx_transpose_scatter_offsets); + const HVX_Vector v_scat_step = Q6_V_vsplat_R(4); + const HVX_VectorPred q_mask64 = Q6_Q_vsetq_R(64); + + unsigned ct = fastdiv((unsigned)start_tile, &n_k_tiles_div); + unsigned kt = fastmodulo((unsigned)start_tile, n_k_tiles, &n_k_tiles_div); + + for (unsigned t = start_tile; t < (unsigned)end_tile; ) { + if (kt >= (unsigned)n_k_tiles) { kt = 0; ct++; } + + __fp16 *tile_base = state->dst + t * HTP_MM_HMX_TILE_N_ELMS; + { + uint32_t byte_off = kt * 32 * sizeof(float); + + HVX_Vector v_off = v_scat_base; + for (uint32_t r = 0; r < HTP_MM_HMX_TILE_N_ROWS; r += 2) { + uint32_t row0 = ct * HTP_MM_HMX_TILE_N_COLS + r; + uint32_t row1 = row0 + 1; + + const uint8_t *r0 = state->src + row0 * state->row_stride; + const uint8_t *r1 = state->src + row1 * state->row_stride; + + HVX_Vector v0_f32 = hvx_vmem((const float *)(r0 + byte_off)); + HVX_Vector v1_f32 = (row1 < state->n_cols) ? hvx_vmem((const float *)(r1 + byte_off)) : Q6_V_vzero(); + + HVX_Vector v_out = hvx_vec_f32_to_f16(v0_f32, v1_f32); + + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v_out); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + + HVX_Vector v_out_hi = Q6_V_vror_VR(v_out, 64); + Q6_vscatter_QRMVwV(q_mask64, (size_t)tile_base, HTP_MM_HMX_TILE_SIZE - 1, v_off, v_out_hi); + v_off = Q6_Vw_vadd_VwVw(v_off, v_scat_step); + } + (void) *(volatile HVX_Vector *)(tile_base); + } + ++t; ++kt; + } + + if (start_tile < end_tile) { + (void) *(volatile HVX_Vector *)(state->dst + (end_tile - 1) * HTP_MM_HMX_TILE_N_ELMS); + } +} + +// --- End tiled dequantizers --- + +// requires external HMX lock +static void core_dot_chunk_fp16(__fp16 *restrict output, const __fp16 *restrict activation, const __fp16 *restrict weight, const __fp16 *restrict scales, + uint32_t n_row_tiles, uint32_t n_col_tiles, uint32_t n_dot_tiles) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)scales); + for (uint32_t r = 0; r < n_row_tiles; ++r) { + for (size_t c = 0; c < n_col_tiles; ++c) { + Q6_mxclracc_hf(); + + const __fp16 *row_tiles = activation + r * n_dot_tiles * HTP_MM_HMX_TILE_N_ELMS; + const __fp16 *col_tiles = weight + c * n_dot_tiles * HTP_MM_HMX_TILE_N_ELMS; + + for (uint32_t k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * (uint32_t)k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + col_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + } + + __fp16 *out_tile = output + (r * n_col_tiles + c) * HTP_MM_HMX_TILE_N_ELMS; + Q6_mxmem_AR_after_hf(out_tile, 0); + } + } +} + +// C += AB +static void core_mma_chunk_fp16(__fp16 *restrict c, const __fp16 *restrict a, const __fp16 *restrict b, + const __fp16 *restrict col_scales, const __fp16 *restrict eye_tile, + uint32_t n_row_tiles, uint32_t n_col_tiles, uint32_t n_dot_tiles, bool zero_init) { + __builtin_assume(n_row_tiles > 0); + __builtin_assume(n_col_tiles > 0); + __builtin_assume(n_dot_tiles > 0); + + Q6_bias_mxmem2_A((void *)col_scales); + + const size_t dot_tile_stride = n_dot_tiles * HTP_MM_HMX_TILE_N_ELMS; + for (size_t i = 0; i < n_row_tiles; ++i) { + const __fp16 *row_base = a + i * dot_tile_stride; + __fp16 *res_base = c + i * n_col_tiles * HTP_MM_HMX_TILE_N_ELMS; + for (size_t j = 0; j < n_col_tiles; ++j) { + Q6_mxclracc_hf(); + + const __fp16 *col_tiles = b + j * dot_tile_stride; + const __fp16 *row_tiles = row_base; + __fp16 *accum_tile = res_base + j * HTP_MM_HMX_TILE_N_ELMS; + if (!zero_init) { + Q6_activation_hf_mxmem_RR((unsigned int)accum_tile, 2047); + Q6_weight_hf_mxmem_RR((unsigned int)eye_tile, 2047); + } + + for (uint32_t k = 0, k_block; k < n_dot_tiles; k += k_block) { + k_block = hex_smin(n_dot_tiles - k, 32); + const uint32_t range = 2048u * k_block - 1; + Q6_activation_hf_mxmem_RR_deep((unsigned int)row_tiles, range); + Q6_weight_hf_mxmem_RR((unsigned int)col_tiles, range); + row_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + col_tiles += k_block * HTP_MM_HMX_TILE_N_ELMS; + } + + Q6_mxmem_AR_after_hf(accum_tile, 0); + } + } +} + +// --- Async HMX matmul job (for pipeline overlap) --- + +typedef struct { + __fp16 * output; + const __fp16 * activation; + const __fp16 * weight; + const __fp16 * scales; + uint32_t n_row_tiles; + uint32_t n_col_tiles; + uint32_t n_dot_tiles; +} hmx_matmul_job_t; + +static void hmx_matmul_worker_fn(void * data) { + hmx_matmul_job_t * job = (hmx_matmul_job_t *) data; + FARF(HIGH, "hmx-mm-job: n_row_tiles %u n_col_tiles %u n_dot_tiles %u", job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); + core_dot_chunk_fp16(job->output, job->activation, job->weight, job->scales, job->n_row_tiles, job->n_col_tiles, job->n_dot_tiles); +} + +static inline void hmx_matmul_job_init(hmx_matmul_job_t * job, + __fp16 * output, + const __fp16 * activation, + const __fp16 * weight, + const __fp16 * scales, + uint32_t n_row_tiles, + uint32_t n_col_tiles, + uint32_t n_dot_tiles) { + job->output = output; + job->activation = activation; + job->weight = weight; + job->scales = scales; + job->n_row_tiles = n_row_tiles; + job->n_col_tiles = n_col_tiles; + job->n_dot_tiles = n_dot_tiles; +} + +// output : fp16 -> f32p + +static void transfer_output_chunk_fp16_to_fp32(float *restrict dst, const __fp16 *restrict vtcm_src, uint32_t start_row, uint32_t n_rows, uint32_t n_cols, uint32_t dst_stride, uint32_t dst_cols) { + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + const size_t limit_c = hex_smin(n_cols, dst_cols); + const size_t limit_c_aligned = (limit_c & ~31); + + for (size_t r = 0; r < n_rows; r += 2) { + const size_t r_idx0 = start_row + r + 0; + const size_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; + const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + float *output_row_base = dst + r * dst_stride; // global memory row base for row r (and r+1) + + #pragma unroll(4) + for (size_t c = 0; c < limit_c_aligned; c += HTP_MM_HMX_TILE_N_COLS) { + const size_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HTP_MM_HMX_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + HVX_Vector *pv_out0 = (HVX_Vector *) (output_row_base + c + 0); + HVX_Vector *pv_out1 = (HVX_Vector *) (output_row_base + c + dst_stride); + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (r + 1 < n_rows) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + + if (limit_c_aligned < limit_c) { + size_t c = limit_c_aligned; + size_t valid_c = limit_c - c; + const size_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HTP_MM_HMX_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + hvx_vec_store_u(output_row_base + c, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp))); + if (r + 1 < n_rows) { + hvx_vec_store_u(output_row_base + c + dst_stride, valid_c * sizeof(float), Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp))); + } + } + } +} + +typedef struct { + const __fp16 *vtcm_src; + float *dst; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t n_cols; + uint32_t dst_stride; // DDR row stride + uint32_t dst_cols; // Actual output columns + struct htp_thread_trace * traces; +} output_transfer_task_state_t; + +// activations : fp32 -> fp16 + +static void transfer_activation_chunk_fp32_to_fp16(__fp16 *restrict vtcm_dst, const float *restrict src, uint32_t n_rows, uint32_t k_block, uint32_t k_stride, uint32_t k_valid) { + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const uint32_t n_rows_tiled = (n_rows / HTP_MM_HMX_TILE_N_ROWS) * HTP_MM_HMX_TILE_N_ROWS; + + uint32_t r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + uint32_t r0 = r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const float *ptr_in0 = src + (r + 0) * k_stride; + const float *ptr_in1 = src + (r + 1) * k_stride; + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = *(const HVX_Vector *)(ptr_in0 + c); + HVX_Vector v1 = *(const HVX_Vector *)(ptr_in1 + c); + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = *(const HVX_Vector *)(ptr_in0 + c); + HVX_Vector v1 = *(const HVX_Vector *)(ptr_in1 + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + uint32_t r0 = r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = r < n_rows; + const bool row1_valid = (r + 1) < n_rows; + + const float *ptr_in0 = row0_valid ? (src + (r + 0) * k_stride) : NULL; + const float *ptr_in1 = row1_valid ? (src + (r + 1) * k_stride) : NULL; + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +typedef struct { + __fp16 *dst; + const float *src; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t k_block; + uint32_t k_stride; + uint32_t k_valid; + struct htp_thread_trace * traces; + struct htp_context * ctx; + float * vtcm_f32_act; +} activation_transfer_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_dma_pipelined( + dma_queue *dma_q, + __fp16 *restrict vtcm_dst, + const float *restrict src, + uint32_t n_rows, + uint32_t k_block, + uint32_t k_stride, + uint32_t k_valid, + float *thread_f32_act) { + + const uint32_t R = HTP_MM_DMA_ACT_ROWS_PER_STEP; + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + + const uint32_t n_steps = n_rows_padded / R; + + // pre-fetch step 0 + if (n_steps > 0 && n_rows > 0) { + uint32_t nrows_to_fetch = hex_smin(n_rows, R); + dma_queue_push(dma_q, dma_make_ptr(thread_f32_act, src), + k_block * sizeof(float), k_stride * sizeof(float), k_valid * sizeof(float), nrows_to_fetch); + } + + for (uint32_t s = 0; s < n_steps; ++s) { + uint32_t r = R * s; + float *curr_buf = thread_f32_act + (s % 2) * R * k_block; + + if (r < n_rows) { + dma_queue_pop(dma_q); + } + + uint32_t next_s = s + 1; + uint32_t next_r = R * next_s; + if (next_r < n_rows) { + uint32_t nrows_to_fetch = hex_smin(n_rows - next_r, R); + const float *next_src = src + next_r * k_stride; + float *next_buf = thread_f32_act + (next_s % 2) * R * k_block; + dma_queue_push(dma_q, dma_make_ptr(next_buf, next_src), + k_block * sizeof(float), k_stride * sizeof(float), k_valid * sizeof(float), nrows_to_fetch); + } + + #pragma unroll + for (uint32_t i = 0; i < HTP_MM_DMA_ACT_ROWS_PER_STEP; i += 2) { + uint32_t curr_r = r + i; + const bool row0_valid = (curr_r < n_rows); + const bool row1_valid = (curr_r + 1) < n_rows; + + const float *ptr_in0 = curr_buf + i * k_block; + const float *ptr_in1 = curr_buf + (i + 1) * k_block; + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t r0 = curr_r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = curr_r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(ptr_in0 + c); + if (row1_valid) v1 = *(const HVX_Vector *)(ptr_in1 + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t r0 = curr_r / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = curr_r % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; // tile column index + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + } +} + +typedef struct { + const struct mmid_row_mapping *matrix_rows; + __fp16 *dst; + const float *src; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t k_block; + uint32_t cur_a; + uint32_t mapping_stride; + uint32_t ne11; + struct fastdiv_values ne11_div; + size_t nb11; + size_t nb12; + uint32_t start_row; + uint32_t cne1; + uint32_t k_valid; + struct htp_thread_trace *traces; +} activation_transfer_gathered_task_state_t; + +typedef struct { + const struct mmid_row_mapping *matrix_rows; + const __fp16 *vtcm_src; + float *dst; + uint32_t n_tasks; + uint32_t n_tot_chunks; + uint32_t n_chunks_per_task; + uint32_t n_cols; + uint32_t cur_a; + uint32_t mapping_stride; + size_t dst_nb1; + size_t dst_nb2; + uint32_t start_row; + uint32_t cne1; + struct htp_thread_trace *traces; +} output_transfer_scattered_task_state_t; + +static void transfer_activation_chunk_fp32_to_fp16_gathered( + __fp16 *restrict vtcm_dst, + const float *restrict src, + uint32_t start_row, + uint32_t n_rows, + uint32_t k_block, + const struct mmid_row_mapping *matrix_rows, + uint32_t cur_a, + uint32_t mapping_stride, + uint32_t ne11, + const struct fastdiv_values * ne11_div, + size_t nb11, + size_t nb12, + uint32_t cne1, + uint32_t k_valid) { + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const uint32_t n_rows_tiled = (n_rows / HTP_MM_HMX_TILE_N_ROWS) * HTP_MM_HMX_TILE_N_ROWS; + + uint32_t r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + uint32_t r_idx0 = start_row + r + 0; + uint32_t r_idx1 = start_row + r + 1; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + uint32_t i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + uint32_t i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + + const float *row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + uint32_t r_idx0 = start_row + r; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; + + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + uint32_t i11_0 = fastmodulo(mapping0.i1, ne11, ne11_div); + row0_ptr = (const float *) ((const uint8_t *) src + i11_0 * nb11 + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + uint32_t i11_1 = fastmodulo(mapping1.i1, ne11, ne11_div); + row1_ptr = (const float *) ((const uint8_t *) src + i11_1 * nb11 + mapping1.i2 * nb12); + } + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +static void transfer_activation_chunk_fp32_to_fp16_gathered_flat( + __fp16 *restrict vtcm_dst, + const float *restrict src, + uint32_t start_row, + uint32_t n_rows, + uint32_t k_block, + const struct mmid_row_mapping *matrix_rows, + uint32_t cur_a, + uint32_t mapping_stride, + size_t nb12, + uint32_t cne1, + uint32_t k_valid) { + const uint32_t n_rows_padded = hex_align_up(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const uint32_t n_rows_tiled = (n_rows / HTP_MM_HMX_TILE_N_ROWS) * HTP_MM_HMX_TILE_N_ROWS; + + uint32_t r = 0; + + #pragma unroll(2) + for (r = 0; r < n_rows_tiled; r += 2) { + uint32_t r_idx0 = start_row + r + 0; + uint32_t r_idx1 = start_row + r + 1; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + + const float *row0_ptr = (const float *) ((const uint8_t *) src + mapping0.i2 * nb12); + const float *row1_ptr = (const float *) ((const uint8_t *) src + mapping1.i2 * nb12); + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = *(const HVX_Vector *)(row0_ptr + c); + HVX_Vector v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } + + for (; r < n_rows_padded; r += 2) { + uint32_t r_idx0 = start_row + r; + uint32_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; // tile row index + uint32_t r1 = r_idx0 % HTP_MM_HMX_TILE_N_ROWS; // intra-tile row idx + + const bool row0_valid = (start_row + r + 0) < cne1; + const bool row1_valid = (start_row + r + 1) < cne1; + + const float *row0_ptr = NULL; + const float *row1_ptr = NULL; + + if (row0_valid) { + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + (start_row + r + 0)]; + row0_ptr = (const float *) ((const uint8_t *) src + mapping0.i2 * nb12); + } + if (row1_valid) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + (start_row + r + 1)]; + row1_ptr = (const float *) ((const uint8_t *) src + mapping1.i2 * nb12); + } + + uint32_t c = 0; + for (; c + 32 <= k_valid; c += 32) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + if (c < k_block) { + HVX_Vector v0 = Q6_V_vzero(); + HVX_Vector v1 = Q6_V_vzero(); + if (row0_valid) v0 = *(const HVX_Vector *)(row0_ptr + c); + if (row1_valid) v1 = *(const HVX_Vector *)(row1_ptr + c); + + uint32_t rem = k_valid - c; + HVX_VectorPred mask = Q6_Q_vsetq2_R(rem > 0 ? rem * sizeof(float) : 0); + v0 = Q6_V_vmux_QVV(mask, v0, Q6_V_vzero()); + v1 = Q6_V_vmux_QVV(mask, v1, Q6_V_vzero()); + + HVX_Vector v_out = hvx_vec_f32_to_f16_shuff(v0, v1); + + uint32_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + uint32_t tile_idx = r0 * (k_block / HTP_MM_HMX_TILE_N_COLS) + c0; + + HVX_Vector *tile = (HVX_Vector *) (vtcm_dst + tile_idx * HTP_MM_HMX_TILE_N_ELMS); + tile[r1 / 2] = v_out; + } + } +} + +static void transfer_output_chunk_fp16_to_fp32_scattered( + float *restrict dst, + const __fp16 *restrict vtcm_src, + uint32_t start_row, + uint32_t n_rows, + uint32_t n_cols, + const struct mmid_row_mapping *matrix_rows, + uint32_t cur_a, + uint32_t mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + uint32_t cne1) { + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + const size_t tile_row_stride = (n_cols / HTP_MM_HMX_TILE_N_COLS) * HTP_MM_HMX_TILE_N_ELMS; + + const HVX_Vector one = hvx_vec_splat_f16(1.0); + + for (size_t r = 0; r < n_rows; r += 2) { + uint32_t r_idx0 = start_row + r + 0; + uint32_t r_idx1 = start_row + r + 1; + const size_t r0 = r_idx0 / HTP_MM_HMX_TILE_N_ROWS; + const size_t r1 = (r_idx0 % HTP_MM_HMX_TILE_N_ROWS) / 2; // index of the row pair within the tile + const __fp16 *row_base = vtcm_src + r0 * tile_row_stride; + + if (r_idx0 >= cne1) break; + + struct mmid_row_mapping mapping0 = matrix_rows[cur_a * mapping_stride + r_idx0]; + float *output_row0 = (float *) ((uint8_t *) dst + mapping0.i1 * dst_nb1 + mapping0.i2 * dst_nb2); + + float *output_row1 = NULL; + if (r_idx1 < cne1) { + struct mmid_row_mapping mapping1 = matrix_rows[cur_a * mapping_stride + r_idx1]; + output_row1 = (float *) ((uint8_t *) dst + mapping1.i1 * dst_nb1 + mapping1.i2 * dst_nb2); + } + + #pragma unroll(4) + for (size_t c = 0; c < (size_t)n_cols; c += HTP_MM_HMX_TILE_N_COLS) { + const size_t c0 = c / HTP_MM_HMX_TILE_N_COLS; + const __fp16 *tile = row_base + c0 * HTP_MM_HMX_TILE_N_ELMS; + HVX_Vector v = ((const HVX_Vector *) tile)[r1]; + HVX_VectorPair vp = Q6_Wqf32_vmpy_VhfVhf(v, one); + + HVX_Vector *pv_out0 = (HVX_Vector *) (output_row0 + c); + HVX_Vector *pv_out1 = output_row1 ? (HVX_Vector *) (output_row1 + c) : NULL; + + *pv_out0 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(vp)); + if (pv_out1) { + *pv_out1 = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(vp)); + } + } + } +} diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.c b/ggml/src/ggml-hexagon/htp/hmx-ops.c deleted file mode 100644 index 114d8c1481..0000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.c +++ /dev/null @@ -1,6 +0,0 @@ -// HMX operations compiled as a single translation unit. -// This allows interprocedural optimizations within HMX ops without requiring global HTP LTO. - -#include "hmx-queue.c" -#include "hmx-matmul-ops.c" -#include "hmx-flash-attn-ops.c" diff --git a/ggml/src/ggml-hexagon/htp/hmx-ops.h b/ggml/src/ggml-hexagon/htp/hmx-ops.h deleted file mode 100644 index a67842f3ff..0000000000 --- a/ggml/src/ggml-hexagon/htp/hmx-ops.h +++ /dev/null @@ -1,88 +0,0 @@ -// HMX operation entry-point declarations. -// Ported from htp-ops-lib/include/dsp/ops.h (renamed, benchmark kernels removed). (https://github.com/haozixu/htp-ops-lib) - -#ifndef HMX_OPS_H -#define HMX_OPS_H - -#include -#include - -#include "htp-ops.h" - -#ifdef __cplusplus -extern "C" { -#endif - -typedef struct { - float *dst; - const float *activation; - const __fp16 *permuted_weight; - int m; - int k; - int n; - int act_stride; - int weight_stride; - int dst_stride; - int ne02; - int ne03; - int ne12; - int ne13; - size_t src0_nb2; - size_t src0_nb3; - size_t src1_nb2; - size_t src1_nb3; - size_t dst_nb2; - size_t dst_nb3; -} hmx_matmul_f16_f32_batched_params_t; - -// HMX matrix multiplication โ€” tile-permuted FP16 weights, FP32 activation/output -// act_stride: activation row stride in elements (= k for contiguous, or -// nb[1]/sizeof(float) for permuted tensors like attention Q). -// weight_stride: weight row stride in elements (= k for compact weights, or -// nb[1]/sizeof(__fp16) for permuted KV-cache views used by QK). -int hmx_matmul_f16_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const __fp16 *permuted_weight, - int m, int k, int n, - int act_stride, - int weight_stride); - -// Batched F16 wrapper over hmx_mat_mul_f16_f32. -// Batch semantics match ggml_mul_mat(): src0 broadcasts to src1 in dims 2/3. -int hmx_matmul_f16_f32_batched(struct htp_context *ctx, const hmx_matmul_f16_f32_batched_params_t *params); - -// HMX matrix multiplication โ€” all supported weight types (F16/F32/Q4_0/Q4_1/Q8_0/IQ4_NL/MXFP4) -int hmx_matmul_2d_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const uint8_t *permuted_weight, - int m, int k, int n, - int act_stride, - int weight_stride, - int weight_type); - -struct mmid_row_mapping; - -int hmx_matmul_id_2d_f32(struct htp_context *ctx, - float *restrict dst, - const float *activation, - const uint8_t *permuted_weight, - int m, int k, int n, - int ne11, - size_t act_nb1, size_t act_nb2, - size_t dst_nb1, size_t dst_nb2, - int weight_stride, - int weight_type, - const struct mmid_row_mapping *matrix_rows, - int cur_a, - int mapping_stride); - -// HMX flash attention -int hmx_flash_attn_ext(struct htp_ops_context * octx); - -#ifdef __cplusplus -} -#endif - -#endif // HMX_OPS_H diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index cbb5d08786..6ad77d3daa 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -13,7 +13,9 @@ #include #include +#ifndef HTP_MAX_NTHREADS #define HTP_MAX_NTHREADS 10 +#endif #define HTP_MAX_MMAPS 16 // Memory mapping @@ -42,9 +44,13 @@ struct htp_ops_context { enum htp_op_code op; // FIXME: rename to opcode int32_t op_params[HTP_OP_MAX_PARAMS]; + int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; const struct htp_tensor * src[HTP_OP_MAX_INPUTS]; - const struct htp_tensor * dst; + union { + const struct htp_tensor * dst; + const struct htp_tensor * dsts[HTP_OP_MAX_OUTPUTS]; + }; // TODO convert these to an array struct htp_spad src0_spad; @@ -87,13 +93,13 @@ struct htp_context { struct htp_ops_context octx; -#ifdef HTP_HAS_HMX struct hmx_queue * hmx_queue; // Async HMX queue for pipeline overlap -#endif }; int op_matmul(struct htp_ops_context * octx); int op_matmul_id(struct htp_ops_context * octx); +int op_matmul_qkv(struct htp_ops_context * octx); +int op_matmul_ffn(struct htp_ops_context * octx); int op_binary(struct htp_ops_context * octx); int op_unary(struct htp_ops_context * octx); int op_sum_rows(struct htp_ops_context * octx); diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 0f4b74a93a..d040901357 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -28,18 +28,19 @@ enum htp_data_type { HTP_TYPE_MXFP4 = 39, // types used internally for repack, dyn.quant, etc - HTP_TYPE_Q4_0x4x2 = 200, - HTP_TYPE_Q4_1x4x2, - HTP_TYPE_Q8_0x4x2, - HTP_TYPE_MXFP4x4x2, + HTP_TYPE_Q4_0_TILED = 200, + HTP_TYPE_Q4_1_TILED, + HTP_TYPE_Q8_0_TILED, + HTP_TYPE_MXFP4_TILED, HTP_TYPE_INVALID }; // Constats for internal types -#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) -#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks -#define QK_MXFP4x4x2 256 // 4x MXFP4 blocks concat with next 4x MXFP4 blocks +#define QK_Q4_0_TILED 256 // 32x32 Q4_0 tiled layout +#define QK_Q8_0_TILED 128 // 32x32 Q8_0 tiled layout +#define QK_MXFP4_TILED 256 // 32x32 MXFP4 tiled layout + // Mask to enable various stages of the Ops. @@ -57,6 +58,8 @@ enum htp_op_code { HTP_OP_DIV = 3, HTP_OP_MUL_MAT, HTP_OP_MUL_MAT_ID, + HTP_OP_MUL_MAT_QKV, + HTP_OP_MUL_MAT_FFN, HTP_OP_RMS_NORM, HTP_OP_RMS_NORM_MUL, HTP_OP_UNARY_SILU, @@ -99,7 +102,9 @@ enum htp_op_code { #define HTP_OP_MAX_DIMS 4 // aka GGML_MAX_DIMS #define HTP_OP_MAX_INPUTS 6 // aka GGML_MAX_SRCS +#define HTP_OP_MAX_OUTPUTS 4 #define HTP_OP_MAX_PARAMS 16 // aka GGML_MAX_OP_PARAMS +#define HTP_OP_MAX_KERN_PARAMS 32 #define HTP_OP_MAX_BUFS 16 #define HTP_OP_MAX_REQS 256 @@ -142,8 +147,10 @@ struct htp_op_desc { uint32_t opcode; // GGML/HTP Op uint32_t flags; // Op flags int32_t params[HTP_OP_MAX_PARAMS]; // Params for the op, e.g. epsilon of RMS norm + int32_t kernel_params[HTP_OP_MAX_KERN_PARAMS]; // generic blob for host-precomputed parameters uint16_t src[HTP_OP_MAX_INPUTS]; // Input tensors indices - uint16_t dst; // Output tensor index + uint16_t dst[HTP_OP_MAX_OUTPUTS]; // Output tensor indices + uint16_t pad[2]; // padding to align to 64 bits }; #ifndef HTP_MAX_NTHREADS diff --git a/ggml/src/ggml-hexagon/htp/htp_iface.idl b/ggml/src/ggml-hexagon/htp/htp_iface.idl index d696a5fba0..47693d8b8b 100644 --- a/ggml/src/ggml-hexagon/htp/htp_iface.idl +++ b/ggml/src/ggml-hexagon/htp/htp_iface.idl @@ -11,12 +11,13 @@ struct htp_iface_pmu_conf { }; interface htp_iface : remote_handle64 { - AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 use_hmx, in uint64 max_vmem); + AEEResult start(in uint32 sess_id, in uint64 dsp_queue_id, in uint32 n_hvx, in uint32 n_hmx, in uint64 max_vmem); AEEResult stop(); AEEResult mmap(in uint32 fd, in uint32 size); AEEResult munmap(in uint32 fd); AEEResult profiler(in uint32 mode, in htp_iface_pmu_conf pmu); AEEResult etm(in uint32 enable); + AEEResult hwinfo(rout uint32 n_threads, rout uint32 n_hvx, rout uint32 n_hmx, rout uint64 vtcm_size); }; #endif /* HTP_IDL */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index f6cb02951d..493b26c6e7 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -170,25 +170,7 @@ static inline HVX_VectorPair hvx_vec_f16_to_f32(HVX_Vector v) { } #endif -/* Q6_Vsf_equals_Vw is only available on v73+.*/ -#if __HVX_ARCH__ < 73 -static inline HVX_Vector hvx_vec_i32_to_qf32(HVX_Vector const in) -{ - HVX_Vector const vzero = Q6_V_vzero(); - HVX_VectorPred is_zero = Q6_Q_vcmp_eq_VwVw(in, vzero); - HVX_Vector lshift = Q6_Vw_vnormamt_Vw(in); - HVX_Vector normalized = Q6_Vw_vasl_VwVw(in, lshift); - HVX_Vector vexp = Q6_Vw_vsub_VwVw(Q6_V_vsplat_R(0x7f + 30), lshift); - HVX_Vector mant = Q6_V_vand_VV(Q6_V_vsplat_R(0xFFFFFF00), normalized); - HVX_Vector ret = Q6_V_vmux_QVV(is_zero, vzero, Q6_Vw_vadd_VwVw(mant, vexp)); - return ret; -} -static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) -{ - return Q6_Vsf_equals_Vqf32(hvx_vec_i32_to_qf32(in)); -} -#endif static inline HVX_Vector hvx_vec_i16_from_hf_rnd_sat(HVX_Vector vin) { // This looks complicated. @@ -305,4 +287,17 @@ static inline HVX_Vector hvx_vec_mul_f32_f32(HVX_Vector a, HVX_Vector b) { #endif // __HVX_ARCH__ < 79 +static inline HVX_Vector hvx_vec_load_act_tile(const uint8_t * y_q, uint32_t kt, HVX_Vector * v_act_all) { + if (kt % 4 == 0) { + *v_act_all = hvx_vmem(y_q + kt * 32); + return *v_act_all; + } else if (kt % 4 == 1) { + return Q6_V_vror_VR(*v_act_all, 32); + } else if (kt % 4 == 2) { + return Q6_V_vror_VR(*v_act_all, 64); + } else { + return Q6_V_vror_VR(*v_act_all, 96); + } +} + #endif /* HVX_BASE_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h new file mode 100644 index 0000000000..52351b1039 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-flat.h @@ -0,0 +1,1024 @@ +// Dynamic quantizers that produce flat (non-tiled) activations + +static inline void quantize_block_f32_q8_0_flat( + float * restrict x, + uint8_t * restrict y_quants, + __fp16 * restrict y_scales, + uint32_t block_idx +) { + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + * (HVX_Vector *) (y_quants + block_idx * 128) = vx_i8; + + HVX_VectorPair vp1 = Q6_W_vshuff_VVR(vd23_hf, vd01_hf, -2); + HVX_VectorPair vp2 = Q6_W_vshuff_VVR(Q6_V_hi_W(vp1), Q6_V_lo_W(vp1), -2); + HVX_Vector v_scales = Q6_V_lo_W(vp2); + hvx_vec_store_u(y_scales + block_idx * 4, 8, v_scales); +} + +static inline void quantize_block_f32_q8_1_flat( + float * restrict x, + uint8_t * restrict y_quants, + __fp16 * restrict y_scales, + uint32_t block_idx +) { + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + * (HVX_Vector *) (y_quants + block_idx * 128) = vx_i8; + + HVX_VectorPair vp1 = Q6_W_vshuff_VVR(vd23_hf, vd01_hf, -2); + HVX_VectorPair vp2 = Q6_W_vshuff_VVR(Q6_V_hi_W(vp1), Q6_V_lo_W(vp1), -2); + HVX_Vector v_scales = Q6_V_lo_W(vp2); + + HVX_VectorPair v_deal1 = Q6_W_vdeal_VVR(v_sums, v_sums, -4); + HVX_Vector v_even1 = Q6_V_lo_W(v_deal1); + HVX_VectorPair v_deal2 = Q6_W_vdeal_VVR(v_even1, v_even1, -4); + HVX_Vector v_even2 = Q6_V_lo_W(v_deal2); + HVX_VectorPair v_deal3 = Q6_W_vdeal_VVR(v_even2, v_even2, -4); + HVX_Vector v_sums_shuffled = Q6_V_lo_W(v_deal3); + + HVX_Vector v_sums_sf = Q6_Vsf_equals_Vw(v_sums_shuffled); + HVX_Vector v_sums_hf = hvx_vec_f32_to_f16(v_sums_sf, Q6_V_vzero()); + + HVX_Vector v_prod = hvx_vec_mul_f16_f16(v_scales, v_sums_hf); + + HVX_VectorPair vp_scales = Q6_W_vshuff_VVR(v_prod, v_scales, -2); + HVX_Vector v_final = Q6_V_lo_W(vp_scales); + + hvx_vec_store_u(y_scales + block_idx * 8, 16, v_final); +} + +static inline void quantize_row_f32_q8_0_flat(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t quants_size = hex_round_up(k, 128); + uint8_t * restrict y_quants = y; + __fp16 * restrict y_scales = (__fp16 *) (y + quants_size); + + const uint32_t nb = (k + 127) / 128; + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_0_flat(x + i * 128, y_quants, y_scales, i); + } +} + +static inline void quantize_row_f32_q8_1_flat(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t quants_size = hex_round_up(k, 128); + uint8_t * restrict y_quants = y; + __fp16 * restrict y_scales = (__fp16 *) (y + quants_size); + + const uint32_t nb = (k + 127) / 128; + for (uint32_t i = 0; i < nb; i++) { + quantize_block_f32_q8_1_flat(x + i * 128, y_quants, y_scales, i); + } +} + +static inline void quantize_f32_q8_0_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_0_flat((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_q8_1_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1_flat((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_f32_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_stride, + size_t dst_stride +) { + (void) tmp_data; + const size_t src_row_size = ne0 * sizeof(float); + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } +} + +static inline void quantize_f32_f16_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_stride, + size_t dst_stride +) { + (void) tmp_data; + const size_t src_row_size = ne0 * sizeof(float); + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_f32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } +} + +static inline void quantize_f16_f16_flat_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_stride, + size_t dst_stride +) { + (void) tmp_data; + const size_t src_row_size = ne0 * sizeof(float); + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_stride, 2); + hvx_copy_f16_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } +} + +// Dot kernels that consume flat (non-tiled) activations + +static void flat_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act_rep, i8); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0_rep, v_act1_rep, i8); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act_rep, Q6_V_vzero()); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + __fp16 scale_a_val = y_scales[kt * 2 + 0]; + __fp16 sum_a_val = y_scales[kt * 2 + 1]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + HVX_Vector v_sum_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a); + HVX_Vector v_offset_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a); + + HVX_Vector v_scaled_dot = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + HVX_Vector v_sum_scaled = hvx_vec_add_f32_f32(v_scaled_dot, v_offset_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0_rep, v_act1_rep, Q6_V_vzero()); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + __fp16 scale_a0_val = y0_scales[kt * 2 + 0]; + __fp16 sum_a0_val = y0_scales[kt * 2 + 1]; + __fp16 scale_a1_val = y1_scales[kt * 2 + 0]; + __fp16 sum_a1_val = y1_scales[kt * 2 + 1]; + + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_sum_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + HVX_Vector v_sum_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&sum_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a0); + HVX_Vector v_offset_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a1); + HVX_Vector v_offset_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a1); + + HVX_Vector v_scaled_dot_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c0 = hvx_vec_add_f32_f32(v_scaled_dot_c0, v_offset_comb_c0); + + HVX_Vector v_scaled_dot_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + HVX_Vector v_sum_scaled_c1 = hvx_vec_add_f32_f32(v_scaled_dot_c1, v_offset_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx_i8 = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx_i8, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_q8_0_32x1(vptr, v_act_rep); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[8]; + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0_i8 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1_i8 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0_i8, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1_i8, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_q8_0_32x2(vptr, v_act0_rep, v_act1_rep); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[8]; + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act_rep, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0_rep, v_act1_rep, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void flat_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y_scales = (const __fp16 *) (y_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx = * (const HVX_Vector *) (y_q + block_idx * 128); + HVX_Vector v_act_raw = Q6_V_vror_VR(vx, sub_idx * 32); + + HVX_Vector v_act_rep[8]; + v_act_rep[0] = Q6_V_vdelta_VV(v_act_raw, v_repl_ctrl); + v_act_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 4), v_repl_ctrl); + v_act_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 8), v_repl_ctrl); + v_act_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 12), v_repl_ctrl); + v_act_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 16), v_repl_ctrl); + v_act_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 20), v_repl_ctrl); + v_act_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 24), v_repl_ctrl); + v_act_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act_raw, 28), v_repl_ctrl); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act_rep, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + __fp16 scale_a_val = y_scales[kt]; + HVX_Vector v_scale_a_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a_val)); + HVX_VectorPair p_scale_a_f32 = hvx_vec_f16_to_f32(v_scale_a_f16); + HVX_Vector v_scale_a = Q6_V_lo_W(p_scale_a_f32); + + HVX_Vector v_scale_comb = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void flat_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + const uint32_t quants_size = hex_round_up(n, 128); + const __fp16 * restrict y0_scales = (const __fp16 *) (y0_q + quants_size); + const __fp16 * restrict y1_scales = (const __fp16 *) (y1_q + quants_size); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + + uint32_t block_idx = kt / 4; + uint32_t sub_idx = kt % 4; + + HVX_Vector vx0 = * (const HVX_Vector *) (y0_q + block_idx * 128); + HVX_Vector vx1 = * (const HVX_Vector *) (y1_q + block_idx * 128); + + HVX_Vector v_act0_raw = Q6_V_vror_VR(vx0, sub_idx * 32); + HVX_Vector v_act1_raw = Q6_V_vror_VR(vx1, sub_idx * 32); + + HVX_Vector v_act0_rep[8]; + v_act0_rep[0] = Q6_V_vdelta_VV(v_act0_raw, v_repl_ctrl); + v_act0_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 4), v_repl_ctrl); + v_act0_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 8), v_repl_ctrl); + v_act0_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 12), v_repl_ctrl); + v_act0_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 16), v_repl_ctrl); + v_act0_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 20), v_repl_ctrl); + v_act0_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 24), v_repl_ctrl); + v_act0_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act0_raw, 28), v_repl_ctrl); + + HVX_Vector v_act1_rep[8]; + v_act1_rep[0] = Q6_V_vdelta_VV(v_act1_raw, v_repl_ctrl); + v_act1_rep[1] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 4), v_repl_ctrl); + v_act1_rep[2] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 8), v_repl_ctrl); + v_act1_rep[3] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 12), v_repl_ctrl); + v_act1_rep[4] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 16), v_repl_ctrl); + v_act1_rep[5] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 20), v_repl_ctrl); + v_act1_rep[6] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 24), v_repl_ctrl); + v_act1_rep[7] = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act1_raw, 28), v_repl_ctrl); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0_rep, v_act1_rep, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + __fp16 scale_a0_val = y0_scales[kt]; + __fp16 scale_a1_val = y1_scales[kt]; + HVX_Vector v_scale_a0_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a0_val)); + HVX_Vector v_scale_a1_f16 = hvx_vec_repl_f16(Q6_Vh_vsplat_R(*(const int16_t *)&scale_a1_val)); + HVX_VectorPair p_scale_a0_f32 = hvx_vec_f16_to_f32(v_scale_a0_f16); + HVX_VectorPair p_scale_a1_f32 = hvx_vec_f16_to_f32(v_scale_a1_f16); + HVX_Vector v_scale_a0 = Q6_V_lo_W(p_scale_a0_f32); + HVX_Vector v_scale_a1 = Q6_V_lo_W(p_scale_a1_f32); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f)); + v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} diff --git a/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h new file mode 100644 index 0000000000..bcb0b8f9e4 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-mm-kernels-tiled.h @@ -0,0 +1,1140 @@ +// Dynamic quantizers that produce tiled activations + +static inline void quantize_block_f32_q8_1_tiled(float * restrict x, uint8_t * restrict y_block) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_block % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); + HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); + HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); + HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + + HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); + HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 + HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); + HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); + + HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); + HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + const HVX_Vector ones = Q6_Vb_vsplat_R(1); + HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); + v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); + + float vmax0[32] __attribute__((aligned(128))); + float vmax1[32] __attribute__((aligned(128))); + float vmax2[32] __attribute__((aligned(128))); + float vmax3[32] __attribute__((aligned(128))); + int32_t sums[32] __attribute__((aligned(128))); + + hvx_vec_store_u(vmax0, 128, vmax0_sf); + hvx_vec_store_u(vmax1, 128, vmax1_sf); + hvx_vec_store_u(vmax2, 128, vmax2_sf); + hvx_vec_store_u(vmax3, 128, vmax3_sf); + hvx_vec_store_u(sums, 128, v_sums); + + float d0 = vmax0[0] / 127.0f; + float d1 = vmax1[0] / 127.0f; + float d2 = vmax2[0] / 127.0f; + float d3 = vmax3[0] / 127.0f; + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + for (int b = 0; b < 4; b++) { + HVX_Vector v_act = Q6_V_vror_VR(vx_i8, b * 32); + + HVX_Vector r0 = Q6_V_vdelta_VV(v_act, v_repl_ctrl); + HVX_Vector r1 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 4), v_repl_ctrl); + HVX_Vector r2 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 8), v_repl_ctrl); + HVX_Vector r3 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 12), v_repl_ctrl); + HVX_Vector r4 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 16), v_repl_ctrl); + HVX_Vector r5 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 20), v_repl_ctrl); + HVX_Vector r6 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 24), v_repl_ctrl); + HVX_Vector r7 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 28), v_repl_ctrl); + + __fp16 scale_h, offset_h; + if (b == 0) { + scale_h = (__fp16) d0; + offset_h = (__fp16) (sums[0] * d0); + } else if (b == 1) { + scale_h = (__fp16) d1; + offset_h = (__fp16) (sums[8] * d1); + } else if (b == 2) { + scale_h = (__fp16) d2; + offset_h = (__fp16) (sums[16] * d2); + } else { + scale_h = (__fp16) d3; + offset_h = (__fp16) (sums[24] * d3); + } + + HVX_Vector r_scale = Q6_Vh_vsplat_R(*(int16_t *)&scale_h); + HVX_Vector r_offset = Q6_Vh_vsplat_R(*(int16_t *)&offset_h); + + HVX_Vector * restrict dst = (HVX_Vector *) (y_block + b * 1280); + dst[0] = r0; + dst[1] = r1; + dst[2] = r2; + dst[3] = r3; + dst[4] = r4; + dst[5] = r5; + dst[6] = r6; + dst[7] = r7; + dst[8] = r_scale; + dst[9] = r_offset; + } +} + +static inline void quantize_block_f32_q8_0_tiled(float * restrict x, uint8_t * restrict y_block) { + assert((unsigned long) x % 128 == 0); + assert((unsigned long) y_block % 128 == 0); + + HVX_Vector * vx = (HVX_Vector *) x; + HVX_Vector zero = Q6_V_vzero(); + + HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); + HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); + HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); + HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); + + HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); + HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); + + HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); + vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); + + HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); + HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); + + HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); + vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); + vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); + + HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); + HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); + HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); + + HVX_Vector r_scale = hvx_vec_repl_f16(vd_hf); + + static const uint8_t __attribute__((aligned(128))) repl[128] = { + 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, + }; + HVX_Vector v_repl_ctrl = * (const HVX_Vector *) repl; + + for (int b = 0; b < 4; b++) { + HVX_Vector v_act = Q6_V_vror_VR(vx_i8, b * 32); + + HVX_Vector r0 = Q6_V_vdelta_VV(v_act, v_repl_ctrl); + HVX_Vector r1 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 4), v_repl_ctrl); + HVX_Vector r2 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 8), v_repl_ctrl); + HVX_Vector r3 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 12), v_repl_ctrl); + HVX_Vector r4 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 16), v_repl_ctrl); + HVX_Vector r5 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 20), v_repl_ctrl); + HVX_Vector r6 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 24), v_repl_ctrl); + HVX_Vector r7 = Q6_V_vdelta_VV(Q6_V_vror_VR(v_act, 28), v_repl_ctrl); + + HVX_Vector * restrict dst = (HVX_Vector *) (y_block + b * 1152); + dst[0] = r0; + dst[1] = r1; + dst[2] = r2; + dst[3] = r3; + dst[4] = r4; + dst[5] = r5; + dst[6] = r6; + dst[7] = r7; + dst[8] = r_scale; + } +} + +static void quantize_row_f32_q8_0_tiled(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (k + qk - 1) / qk; + + for (uint32_t i = 0; i < nb; i++) { + uint8_t * restrict y_block = y + i * 4 * 1152; + quantize_block_f32_q8_0_tiled(x + i * qk, y_block); + } +} + +static void quantize_row_f32_q8_1_tiled(float * restrict x, uint8_t * restrict y, uint32_t k) { + assert(k % 32 == 0); + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (k + qk - 1) / qk; + + for (uint32_t i = 0; i < nb; i++) { + uint8_t * restrict y_block = y + i * 4 * 1280; + quantize_block_f32_q8_1_tiled(x + i * qk, y_block); + } +} + +// Dot kernels & helpers that consume tiled activations + +static inline HVX_Vector hvx_vec_mul_f16_f16_to_f32_lower32(HVX_Vector v1, HVX_Vector v2) { +#if __HVX_ARCH__ >= 79 + HVX_VectorPair p = Q6_Wsf_vmpy_VhfVhf(v1, v2); + return Q6_V_lo_W(Q6_W_vshuff_VVR(Q6_V_hi_W(p), Q6_V_lo_W(p), -4)); +#else + HVX_VectorPair p = Q6_Wqf32_vmpy_VhfVhf(v1, v2); + HVX_Vector hi = Q6_Vsf_equals_Vqf32(Q6_V_hi_W(p)); + HVX_Vector lo = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(p)); + return Q6_V_lo_W(Q6_W_vshuff_VVR(hi, lo, -4)); +#endif +} + +static inline HVX_Vector unpack_and_interleave_4bit(HVX_Vector v_a, HVX_Vector v_b, HVX_Vector mask_h4) { + HVX_Vector v_W0 = Q6_V_vand_VV(v_a, mask_h4); + HVX_Vector v_W1 = Q6_Vub_vlsr_VubR(v_a, 4); + HVX_Vector v_W2 = Q6_V_vand_VV(v_b, mask_h4); + HVX_Vector v_W3 = Q6_Vub_vlsr_VubR(v_b, 4); + + HVX_VectorPair v01_pair = Q6_W_vshuff_VVR(v_W1, v_W0, -1); + HVX_VectorPair v23_pair = Q6_W_vshuff_VVR(v_W3, v_W2, -1); + HVX_VectorPair v0123_pair = Q6_W_vshuff_VVR(Q6_V_lo_W(v23_pair), Q6_V_lo_W(v01_pair), -2); + return Q6_V_lo_W(v0123_pair); +} + +static inline HVX_VectorPair unpack_and_interleave_4bit_x2(HVX_Vector v_src, HVX_Vector mask_h4) { + HVX_Vector v_lo = Q6_V_vand_VV(v_src, mask_h4); + HVX_Vector v_hi = Q6_Vub_vlsr_VubR(v_src, 4); + HVX_VectorPair v01_pair = Q6_W_vshuff_VVR(v_hi, v_lo, -1); + HVX_Vector v01_lo = Q6_V_lo_W(v01_pair); + HVX_Vector v01_hi = Q6_V_hi_W(v01_pair); + + HVX_Vector v23_lo = Q6_V_valign_VVR(v01_hi, v01_lo, 64); + HVX_Vector v_W0 = Q6_V_lo_W(Q6_W_vshuff_VVR(v23_lo, v01_lo, -2)); + + HVX_Vector v67_lo = Q6_V_valign_VVR(v01_lo, v01_hi, 64); + HVX_Vector v_W1 = Q6_V_lo_W(Q6_W_vshuff_VVR(v67_lo, v01_hi, -2)); + + return Q6_W_vcombine_VV(v_W1, v_W0); +} + +static inline HVX_Vector accum_4bit_32x1( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act, + HVX_Vector i8 +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v_W_pair), i8); + HVX_Vector v_W1 = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v_W_pair), i8); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act[i * 2 + 1]); + } + + return Q6_Vw_vadd_VwVw(v_sum0, v_sum1); +} + +static inline HVX_Vector accum_4bit_32x1_lut( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act, + HVX_Vector mask_h4, + HVX_Vector lut +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v_W_pair), lut, 0); + HVX_Vector v_W1 = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v_W_pair), lut, 0); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act[i * 2 + 1]); + } + + return Q6_Vw_vadd_VwVw(v_sum0, v_sum1); +} + +static inline HVX_VectorPair accum_4bit_32x2( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act0, + const HVX_Vector * restrict v_act1, + HVX_Vector i8 +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v_W_pair), i8); + HVX_Vector v_W1 = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v_W_pair), i8); + + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act0[i * 2 + 0]); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W1, v_act0[i * 2 + 1]); + + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W0, v_act1[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act1[i * 2 + 1]); + } + + return Q6_W_vcombine_VV(v_sum1, v_sum0); +} + +static inline HVX_VectorPair accum_4bit_32x2_lut( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act0, + const HVX_Vector * restrict v_act1, + HVX_Vector mask_h4, + HVX_Vector lut +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + + #pragma unroll + for (int i = 0; i < 4; i++) { + HVX_VectorPair v_W_pair = unpack_and_interleave_4bit_x2(vptr[i], mask_h4); + HVX_Vector v_W0 = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v_W_pair), lut, 0); + HVX_Vector v_W1 = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v_W_pair), lut, 0); + + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W0, v_act0[i * 2 + 0]); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W1, v_act0[i * 2 + 1]); + + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W0, v_act1[i * 2 + 0]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W1, v_act1[i * 2 + 1]); + } + + return Q6_W_vcombine_VV(v_sum1, v_sum0); +} + +static inline HVX_Vector accum_q8_0_32x1( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act +) { + HVX_Vector v_sum = Q6_V_vzero(); + #pragma unroll + for (int g = 0; g < 8; g++) { + HVX_Vector v_rot = Q6_V_vror_VR(vptr[g], 64); + HVX_Vector v_W = Q6_V_lo_W(Q6_W_vshuff_VVR(v_rot, vptr[g], -2)); + v_sum = Q6_Vw_vrmpyacc_VwVbVb(v_sum, v_W, v_act[g]); + } + return v_sum; +} + +static inline HVX_VectorPair accum_q8_0_32x2( + const HVX_Vector * restrict vptr, + const HVX_Vector * restrict v_act0, + const HVX_Vector * restrict v_act1 +) { + HVX_Vector v_sum0 = Q6_V_vzero(); + HVX_Vector v_sum1 = Q6_V_vzero(); + #pragma unroll + for (int g = 0; g < 8; g++) { + HVX_Vector v_rot = Q6_V_vror_VR(vptr[g], 64); + HVX_Vector v_W = Q6_V_lo_W(Q6_W_vshuff_VVR(v_rot, vptr[g], -2)); + v_sum0 = Q6_Vw_vrmpyacc_VwVbVb(v_sum0, v_W, v_act0[g]); + v_sum1 = Q6_Vw_vrmpyacc_VwVbVb(v_sum1, v_W, v_act1[g]); + } + return Q6_W_vcombine_VV(v_sum1, v_sum0); +} + +static void tiled_vec_dot_q4_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act, i8); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_q4_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector i8 = Q6_Vb_vsplat_R(8); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_4bit_32x2(vptr0, v_act0_0, v_act1_0, i8); + HVX_VectorPair v_sums1 = accum_4bit_32x2(vptr1, v_act0_1, v_act1_1, i8); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = vptr0[4]; + HVX_Vector v_scale_w1 = vptr1[4]; + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0, v_act1, i8); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_q4_1_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1280); + + HVX_Vector v_sum = accum_4bit_32x1(vptr, v_act, Q6_V_vzero()); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_sum_a = v_act[9]; + + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a); + HVX_Vector v_offset_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a); + + HVX_Vector v_scaled_dot = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + HVX_Vector v_sum_scaled = hvx_vec_add_f32_f32(v_scaled_dot, v_offset_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_q4_1_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1280); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1280); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1280); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1280); + + HVX_VectorPair v_sums0 = accum_4bit_32x2(vptr0, v_act0_0, v_act1_0, Q6_V_vzero()); + HVX_VectorPair v_sums1 = accum_4bit_32x2(vptr1, v_act0_1, v_act1_1, Q6_V_vzero()); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_offset0 = vptr0[4]; + HVX_VectorPair p_deal0 = Q6_W_vdeal_VVR(v_scale_offset0, v_scale_offset0, -2); + HVX_Vector v_scale0 = Q6_V_lo_W(p_deal0); + HVX_Vector v_offset0 = Q6_V_hi_W(p_deal0); + + HVX_Vector v_scale_offset1 = vptr1[4]; + HVX_VectorPair p_deal1 = Q6_W_vdeal_VVR(v_scale_offset1, v_scale_offset1, -2); + HVX_Vector v_scale1 = Q6_V_lo_W(p_deal1); + HVX_Vector v_offset1 = Q6_V_hi_W(p_deal1); + + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_sum_a_c0_0 = v_act0_0[9]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_sum_a_c1_0 = v_act1_0[9]; + + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_sum_a_c0_1 = v_act0_1[9]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + HVX_Vector v_sum_a_c1_1 = v_act1_1[9]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale0, v_scale_a_c0_0); + HVX_Vector v_offset_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset0, v_sum_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale0, v_scale_a_c1_0); + HVX_Vector v_offset_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset0, v_sum_a_c1_0); + + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale1, v_scale_a_c0_1); + HVX_Vector v_offset_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset1, v_sum_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale1, v_scale_a_c1_1); + HVX_Vector v_offset_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset1, v_sum_a_c1_1); + + HVX_Vector v_scaled_dot_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_add_f32_f32(v_scaled_dot_c0_0, v_offset_comb_c0_0); + + HVX_Vector v_scaled_dot_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_add_f32_f32(v_scaled_dot_c1_0, v_offset_comb_c1_0); + + HVX_Vector v_scaled_dot_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_add_f32_f32(v_scaled_dot_c0_1, v_offset_comb_c0_1); + + HVX_Vector v_scaled_dot_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_add_f32_f32(v_scaled_dot_c1_1, v_offset_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1280); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1280); + + HVX_VectorPair v_sums = accum_4bit_32x2(vptr, v_act0, v_act1, Q6_V_vzero()); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_offset = vptr[4]; + HVX_VectorPair p_deal = Q6_W_vdeal_VVR(v_scale_offset, v_scale_offset, -2); + HVX_Vector v_scale = Q6_V_lo_W(p_deal); + HVX_Vector v_offset = Q6_V_hi_W(p_deal); + + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_sum_a_c0 = v_act0[9]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + HVX_Vector v_sum_a_c1 = v_act1[9]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a_c0); + HVX_Vector v_offset_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale, v_scale_a_c1); + HVX_Vector v_offset_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_offset, v_sum_a_c1); + + HVX_Vector v_scaled_dot_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c0 = hvx_vec_add_f32_f32(v_scaled_dot_c0, v_offset_comb_c0); + + HVX_Vector v_scaled_dot_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + HVX_Vector v_sum_scaled_c1 = hvx_vec_add_f32_f32(v_scaled_dot_c1, v_offset_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_q8_0_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_q8_0_32x1(vptr, v_act); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[8]; + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_q8_0_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 1152); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 1152); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_q8_0_32x2(vptr0, v_act0_0, v_act1_0); + HVX_VectorPair v_sums1 = accum_q8_0_32x2(vptr1, v_act0_1, v_act1_1); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = vptr0[8]; + HVX_Vector v_scale_w1 = vptr1[8]; + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 1152); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_q8_0_32x2(vptr, v_act0, v_act1); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[8]; + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_iq4nl_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a = v_act[8]; + HVX_Vector v_scale_comb = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_iq4nl_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_4bit_32x2_lut(vptr0, v_act0_0, v_act1_0, mask_h4, lut); + HVX_VectorPair v_sums1 = accum_4bit_32x2_lut(vptr1, v_act0_1, v_act1_1, mask_h4, lut); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = vptr0[4]; + HVX_Vector v_scale_w1 = vptr1[4]; + HVX_Vector v_scale_a_c0_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_1 = v_act1_1[8]; + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0, v_act1, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = vptr[4]; + HVX_Vector v_scale_a_c0 = v_act0[8]; + HVX_Vector v_scale_a_c1 = v_act1[8]; + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f16_f16_to_f32_lower32(v_scale_w, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static void tiled_vec_dot_mxfp4_32x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y_q = vy; + + HVX_Vector v_sum_float = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + uint32_t n_k_tiles = n / 32; + for (uint32_t kt = 0; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act = (const HVX_Vector *) (y_q + kt * 1152); + + HVX_Vector v_sum = accum_4bit_32x1_lut(vptr, v_act, mask_h4, lut); + HVX_Vector v_sum_sf = Q6_Vsf_equals_Vw(v_sum); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector v_scale_a_f16 = v_act[8]; + HVX_VectorPair p_scale_a_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_f16); + HVX_Vector v_scale_a = Q6_V_lo_W(p_scale_a_f32); + + HVX_Vector v_scale_comb = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a); + HVX_Vector v_sum_scaled = hvx_vec_mul_f32_f32(v_sum_sf, v_scale_comb); + + v_sum_float = hvx_vec_add_f32_f32(v_sum_float, v_sum_scaled); + } + + v_sum_float = hvx_vec_mul_f32_f32(v_sum_float, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s, valid_rows * sizeof(float), v_sum_float); +} + +static void tiled_vec_dot_mxfp4_32x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx, const void * restrict vy0, const void * restrict vy1, uint32_t valid_rows) { + const uint8_t * restrict tile_ptr = vx; + const uint8_t * restrict y0_q = vy0; + const uint8_t * restrict y1_q = vy1; + + HVX_Vector v_sum_float_c0 = Q6_V_vzero(); + HVX_Vector v_sum_float_c1 = Q6_V_vzero(); + HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + + uint32_t n_k_tiles = n / 32; + uint32_t kt = 0; + for (; kt + 1 < n_k_tiles; kt += 2) { + const HVX_Vector * restrict vptr0 = (const HVX_Vector *) (tile_ptr + (kt + 0) * 640); + const HVX_Vector * restrict v_act0_0 = (const HVX_Vector *) (y0_q + (kt + 0) * 1152); + const HVX_Vector * restrict v_act1_0 = (const HVX_Vector *) (y1_q + (kt + 0) * 1152); + + const HVX_Vector * restrict vptr1 = (const HVX_Vector *) (tile_ptr + (kt + 1) * 640); + const HVX_Vector * restrict v_act0_1 = (const HVX_Vector *) (y0_q + (kt + 1) * 1152); + const HVX_Vector * restrict v_act1_1 = (const HVX_Vector *) (y1_q + (kt + 1) * 1152); + + HVX_VectorPair v_sums0 = accum_4bit_32x2_lut(vptr0, v_act0_0, v_act1_0, mask_h4, lut); + HVX_VectorPair v_sums1 = accum_4bit_32x2_lut(vptr1, v_act0_1, v_act1_1, mask_h4, lut); + + HVX_Vector v_sum_c0_0 = Q6_V_lo_W(v_sums0); + HVX_Vector v_sum_c1_0 = Q6_V_hi_W(v_sums0); + HVX_Vector v_sum_c0_1 = Q6_V_lo_W(v_sums1); + HVX_Vector v_sum_c1_1 = Q6_V_hi_W(v_sums1); + + HVX_Vector v_sum_sf_c0_0 = Q6_Vsf_equals_Vw(v_sum_c0_0); + HVX_Vector v_sum_sf_c1_0 = Q6_Vsf_equals_Vw(v_sum_c1_0); + HVX_Vector v_sum_sf_c0_1 = Q6_Vsf_equals_Vw(v_sum_c0_1); + HVX_Vector v_sum_sf_c1_1 = Q6_Vsf_equals_Vw(v_sum_c1_1); + + HVX_Vector v_scale_w0 = hvx_vmem(tile_ptr + (kt + 0) * 640 + 512); + HVX_Vector r0_d0 = Q6_V_vdelta_VV(v_scale_w0, expand); + r0_d0 = Q6_V_vand_VV(r0_d0, e8m0_mask); + HVX_Vector v_scale_w_f32_0 = Q6_Vw_vasl_VwR(r0_d0, 23); + + HVX_Vector v_scale_w1 = hvx_vmem(tile_ptr + (kt + 1) * 640 + 512); + HVX_Vector r0_d1 = Q6_V_vdelta_VV(v_scale_w1, expand); + r0_d1 = Q6_V_vand_VV(r0_d1, e8m0_mask); + HVX_Vector v_scale_w_f32_1 = Q6_Vw_vasl_VwR(r0_d1, 23); + + HVX_Vector v_scale_a_c0_f16_0 = v_act0_0[8]; + HVX_Vector v_scale_a_c1_f16_0 = v_act1_0[8]; + HVX_Vector v_scale_a_c0_f16_1 = v_act0_1[8]; + HVX_Vector v_scale_a_c1_f16_1 = v_act1_1[8]; + + HVX_VectorPair p_scale_a_c0_f32_0 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16_0); + HVX_VectorPair p_scale_a_c1_f32_0 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16_0); + HVX_VectorPair p_scale_a_c0_f32_1 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16_1); + HVX_VectorPair p_scale_a_c1_f32_1 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16_1); + + HVX_Vector v_scale_a_c0_0 = Q6_V_lo_W(p_scale_a_c0_f32_0); + HVX_Vector v_scale_a_c1_0 = Q6_V_lo_W(p_scale_a_c1_f32_0); + HVX_Vector v_scale_a_c0_1 = Q6_V_lo_W(p_scale_a_c0_f32_1); + HVX_Vector v_scale_a_c1_1 = Q6_V_lo_W(p_scale_a_c1_f32_1); + + HVX_Vector v_scale_comb_c0_0 = hvx_vec_mul_f32_f32(v_scale_w_f32_0, v_scale_a_c0_0); + HVX_Vector v_scale_comb_c1_0 = hvx_vec_mul_f32_f32(v_scale_w_f32_0, v_scale_a_c1_0); + HVX_Vector v_scale_comb_c0_1 = hvx_vec_mul_f32_f32(v_scale_w_f32_1, v_scale_a_c0_1); + HVX_Vector v_scale_comb_c1_1 = hvx_vec_mul_f32_f32(v_scale_w_f32_1, v_scale_a_c1_1); + + HVX_Vector v_sum_scaled_c0_0 = hvx_vec_mul_f32_f32(v_sum_sf_c0_0, v_scale_comb_c0_0); + HVX_Vector v_sum_scaled_c1_0 = hvx_vec_mul_f32_f32(v_sum_sf_c1_0, v_scale_comb_c1_0); + HVX_Vector v_sum_scaled_c0_1 = hvx_vec_mul_f32_f32(v_sum_sf_c0_1, v_scale_comb_c0_1); + HVX_Vector v_sum_scaled_c1_1 = hvx_vec_mul_f32_f32(v_sum_sf_c1_1, v_scale_comb_c1_1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, hvx_vec_add_f32_f32(v_sum_scaled_c0_0, v_sum_scaled_c0_1)); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, hvx_vec_add_f32_f32(v_sum_scaled_c1_0, v_sum_scaled_c1_1)); + } + + for (; kt < n_k_tiles; kt++) { + const HVX_Vector * restrict vptr = (const HVX_Vector *) (tile_ptr + kt * 640); + const HVX_Vector * restrict v_act0 = (const HVX_Vector *) (y0_q + kt * 1152); + const HVX_Vector * restrict v_act1 = (const HVX_Vector *) (y1_q + kt * 1152); + + HVX_VectorPair v_sums = accum_4bit_32x2_lut(vptr, v_act0, v_act1, mask_h4, lut); + HVX_Vector v_sum_c0 = Q6_V_lo_W(v_sums); + HVX_Vector v_sum_c1 = Q6_V_hi_W(v_sums); + + HVX_Vector v_sum_sf_c0 = Q6_Vsf_equals_Vw(v_sum_c0); + HVX_Vector v_sum_sf_c1 = Q6_Vsf_equals_Vw(v_sum_c1); + + HVX_Vector v_scale_w = hvx_vmem(tile_ptr + kt * 640 + 512); + HVX_Vector r0_d = Q6_V_vdelta_VV(v_scale_w, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + HVX_Vector v_scale_w_f32 = Q6_Vw_vasl_VwR(r0_d, 23); + + HVX_Vector v_scale_a_c0_f16 = v_act0[8]; + HVX_Vector v_scale_a_c1_f16 = v_act1[8]; + + HVX_VectorPair p_scale_a_c0_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_c0_f16); + HVX_VectorPair p_scale_a_c1_f32 = hvx_vec_f16_to_f32_shuff(v_scale_a_c1_f16); + + HVX_Vector v_scale_a_c0 = Q6_V_lo_W(p_scale_a_c0_f32); + HVX_Vector v_scale_a_c1 = Q6_V_lo_W(p_scale_a_c1_f32); + + HVX_Vector v_scale_comb_c0 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a_c0); + HVX_Vector v_scale_comb_c1 = hvx_vec_mul_f32_f32(v_scale_w_f32, v_scale_a_c1); + + HVX_Vector v_sum_scaled_c0 = hvx_vec_mul_f32_f32(v_sum_sf_c0, v_scale_comb_c0); + HVX_Vector v_sum_scaled_c1 = hvx_vec_mul_f32_f32(v_sum_sf_c1, v_scale_comb_c1); + + v_sum_float_c0 = hvx_vec_add_f32_f32(v_sum_float_c0, v_sum_scaled_c0); + v_sum_float_c1 = hvx_vec_add_f32_f32(v_sum_float_c1, v_sum_scaled_c1); + } + + v_sum_float_c0 = hvx_vec_mul_f32_f32(v_sum_float_c0, hvx_vec_splat_f32(0.5f)); + v_sum_float_c1 = hvx_vec_mul_f32_f32(v_sum_float_c1, hvx_vec_splat_f32(0.5f)); + + hvx_vec_store_u(s0, valid_rows * sizeof(float), v_sum_float_c0); + hvx_vec_store_u(s1, valid_rows * sizeof(float), v_sum_float_c1); +} + +static inline void quantize_f32_q8_0_tiled_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_0_tiled((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_q8_1_tiled_kernel( + const uint8_t * restrict src_data, + uint8_t * restrict dst_data, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t nrows, + size_t src_row_size, + size_t dst_row_size +) { + const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0_TILED * sizeof(float)); + hvx_splat_f32_a(tmp_data, 0.0f, src_row_size_padded / sizeof(float)); + + for (uint32_t i = 0; i < nrows; ++i) { + hex_l2fetch(src_data, src_row_size, src_row_size, 2); + hvx_copy_f32_aa(tmp_data, src_data, ne0); + + quantize_row_f32_q8_1_tiled((float *) tmp_data, dst_data, ne0); + dst_data += dst_row_size; + src_data += src_row_size; + } +} + +static inline void quantize_f32_q8_0_tiled_block_kernel( + const float * restrict src, + uint8_t * restrict dst, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t ib_first, + uint32_t ib_last, + size_t src_row_size, + size_t dst_row_size, + uint32_t r, + uint32_t c +) { + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne0 + qk - 1) / qk; + + for (uint32_t ib = ib_first; ib < ib_last; ++ib) { + const uint8_t * restrict src_ptr = (const uint8_t *) src + r * src_row_size + c * qk * sizeof(float); + uint8_t * restrict dst_ptr = dst + r * dst_row_size + c * 4 * 1152; + + hex_l2fetch(src_ptr, qk * sizeof(float), qk * sizeof(float), 1); + + if (c == nb - 1) { + uint32_t active_elements = ne0 - c * qk; + hvx_splat_f32_a(tmp_data, 0.0f, qk); + hvx_copy_f32_aa(tmp_data, src_ptr, active_elements); + } else { + hvx_copy_f32_aa(tmp_data, src_ptr, qk); + } + + quantize_block_f32_q8_0_tiled((float *) tmp_data, dst_ptr); + + c++; + if (c == nb) { + c = 0; + r++; + } + } +} + +static inline void quantize_f32_q8_1_tiled_block_kernel( + const float * restrict src, + uint8_t * restrict dst, + uint8_t * restrict tmp_data, + uint32_t ne0, + uint32_t ib_first, + uint32_t ib_last, + size_t src_row_size, + size_t dst_row_size, + uint32_t r, + uint32_t c +) { + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne0 + qk - 1) / qk; + + for (uint32_t ib = ib_first; ib < ib_last; ++ib) { + const uint8_t * restrict src_ptr = (const uint8_t *) src + r * src_row_size + c * qk * sizeof(float); + uint8_t * restrict dst_ptr = dst + r * dst_row_size + c * 4 * 1280; + + hex_l2fetch(src_ptr, qk * sizeof(float), qk * sizeof(float), 1); + + if (c == nb - 1) { + uint32_t active_elements = ne0 - c * qk; + hvx_splat_f32_a(tmp_data, 0.0f, qk); + hvx_copy_f32_aa(tmp_data, src_ptr, active_elements); + } else { + hvx_copy_f32_aa(tmp_data, src_ptr, qk); + } + + quantize_block_f32_q8_1_tiled((float *) tmp_data, dst_ptr); + + c++; + if (c == nb) { + c = 0; + r++; + } + } +} diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 53ab33c07b..d76512ea4a 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -361,7 +361,7 @@ static void vtcm_free(struct htp_context * ctx) { static void htp_packet_callback(dspqueue_t queue, int error, void * context); static void htp_error_callback(dspqueue_t queue, int error, void * context); -AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_queue_id, uint32 n_hvx, uint32 use_hmx, uint64_t max_vmem) { +AEEResult htp_iface_start(remote_handle64 handle, uint32_t sess_id, uint64_t dsp_queue_id, uint32_t n_hvx, uint32_t n_hmx, uint64_t max_vmem) { struct htp_context * ctx = (struct htp_context *) handle; if (!ctx) { @@ -395,10 +395,9 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que return AEE_ENOMEMORY; } -#ifdef HTP_HAS_HMX - ctx->hmx_enabled = use_hmx; + ctx->hmx_enabled = n_hmx; ctx->hmx_queue = NULL; - if (use_hmx) { + if (n_hmx) { ctx->hmx_queue = hmx_queue_create(16, ctx->vtcm_rctx); if (ctx->hmx_queue) { ctx->hmx_queue->trace = &ctx->trace[HTP_MAX_NTHREADS]; @@ -407,8 +406,7 @@ AEEResult htp_iface_start(remote_handle64 handle, uint32 sess_id, uint64 dsp_que ctx->hmx_enabled = false; } } - FARF(HIGH, "HMX %s (use_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", use_hmx); -#endif + FARF(HIGH, "HMX %s (n_hmx=%d)", ctx->hmx_enabled ? "enabled" : "disabled", n_hmx); qurt_sysenv_max_hthreads_t hw_threads; qurt_sysenv_get_max_hw_threads(&hw_threads); @@ -481,13 +479,11 @@ AEEResult htp_iface_stop(remote_handle64 handle) { dma_queue_delete(ctx->dma[i]); } -#ifdef HTP_HAS_HMX if (ctx->hmx_queue) { hmx_queue_delete(ctx->hmx_queue); ctx->hmx_queue = NULL; } ctx->hmx_enabled = false; -#endif vtcm_free(ctx); @@ -500,6 +496,36 @@ AEEResult htp_iface_stop(remote_handle64 handle) { return AEE_SUCCESS; } +AEEResult htp_iface_hwinfo(remote_handle64 handle, uint32_t * n_threads, uint32_t * n_hvx, uint32_t * n_hmx, uint64_t * vtcm_size) { + (void)handle; + if (!n_threads || !n_hvx || !n_hmx || !vtcm_size) { + return AEE_EBADPARM; + } + + qurt_sysenv_max_hthreads_t hw_threads; + qurt_sysenv_get_max_hw_threads(&hw_threads); + uint32_t hw_nhvx = (qurt_hvx_get_units() >> 8) & 0xFF; + + uint32_t n_hvx_val = hw_nhvx; + if (n_hvx_val > hw_threads.max_hthreads) { + n_hvx_val = hw_threads.max_hthreads; + } + if (n_hvx_val > HTP_MAX_NTHREADS) { + n_hvx_val = HTP_MAX_NTHREADS; + } + + // for now we force n_threads == n_hvx + *n_threads = n_hvx_val; + *n_hvx = n_hvx_val; + *n_hmx = 1; + + uint32_t vtcm_sz = 8 * 1024 * 1024; // 8MB default fallback + HAP_compute_res_query_VTCM(0, (unsigned int *)&vtcm_sz, NULL, NULL, NULL); + *vtcm_size = vtcm_sz; + + return AEE_SUCCESS; +} + static void htp_error_callback(dspqueue_t queue, int error, void * context) { // No errors expected on the DSP. FARF(ERROR, "Error callback: 0x%08x", (unsigned) error); @@ -554,6 +580,12 @@ static int execute_op(struct htp_ops_context * octx) { case HTP_OP_MUL_MAT_ID: return op_matmul_id(octx); + case HTP_OP_MUL_MAT_QKV: + return op_matmul_qkv(octx); + + case HTP_OP_MUL_MAT_FFN: + return op_matmul_ffn(octx); + case HTP_OP_MUL: case HTP_OP_ADD: case HTP_OP_SUB: @@ -762,8 +794,9 @@ static void prep_tensors(struct htp_context *ctx, struct htp_buf_desc *bufs, str } } -static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { +static int proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, uint32_t idx, struct htp_op_desc * op) { memcpy(octx->op_params, op->params, sizeof(octx->op_params)); + memcpy(octx->kernel_params, op->kernel_params, sizeof(octx->kernel_params)); octx->flags = op->flags; octx->op = op->opcode; @@ -785,22 +818,41 @@ static void proc_op_req(struct htp_ops_context * octx, struct htp_tensor *tens, src->ne[0], src->ne[1], src->ne[3], src->ne[3]); } - // Prep output tensor - struct htp_tensor *dst = tens + op->dst; + // Prep output tensors + for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) { + uint16_t dst_idx = op->dst[i]; + if (dst_idx == 0xffff) { + octx->dsts[i] = NULL; + continue; + } + struct htp_tensor *dst = tens + dst_idx; + octx->dsts[i] = dst; - octx->dst = dst; + FARF(HIGH, "prep-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, dst_idx, (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); + } - FARF(HIGH, "prep-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, - dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + int status = execute_op(octx); - (void) execute_op(octx); + octx->src0_spad.src = NULL; + octx->src1_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->src3_spad.src = NULL; + octx->dst_spad.src = NULL; // flush buffers on output - hex_l2flush((void *) dst->data, dst->size); - dst->flags |= HTP_TENSOR_FLUSHED; + for (uint32_t i = 0; i < HTP_OP_MAX_OUTPUTS; i++) { + if (octx->dsts[i]) { + struct htp_tensor *dst = (struct htp_tensor *)octx->dsts[i]; + hex_l2flush((void *) dst->data, dst->size); + dst->flags |= HTP_TENSOR_FLUSHED; - FARF(HIGH, "post-dst #%u: data %p size %u : %u:%u:%u:%u", op->dst, (void*) dst->data, dst->size, - dst->ne[0], dst->ne[1], dst->ne[3], dst->ne[3]); + FARF(HIGH, "post-dst[%u] #%u: data %p size %u : %u:%u:%u:%u", i, op->dst[i], (void*) dst->data, dst->size, + dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3]); + } + } + + return status; } #define DSPQUEUE_POLL_TIMEOUT_USEC 100 @@ -892,20 +944,26 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { } } + int op_status = HTP_STATUS_OK; + uint32_t op_wakeup = n_ops / 2; // half-way throgh the batch + for (uint32_t i=0; i < n_ops; i++) { struct profile_data prof; - if (i == (n_ops-1)) { - // wake up the host before starting the last op + if (i == op_wakeup) { dspqueue_write_early_wakeup_noblock(queue, 0, 0); } profile_start(ctx->profiler, &prof); - proc_op_req(octx, tens, i, &ops[i]); + op_status = proc_op_req(octx, tens, i, &ops[i]); profile_stop(ctx->profiler, &prof); + if (op_status != HTP_STATUS_OK) { + break; + } + if (ctx->profiler) { pds[i].opcode = ops[i].opcode; pds[i].usecs = prof.usecs; @@ -919,7 +977,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_opbatch_rsp rsp; rsp.id = req.id; - rsp.status = HTP_STATUS_OK; + rsp.status = op_status; rsp.n_bufs = n_bufs; rsp.n_tensors = n_tens; rsp.n_ops = n_ops; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index 8e016c1be5..81a0ffbebb 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -5,6 +5,7 @@ #include #include +#include #include #include @@ -17,33 +18,50 @@ #include "ggml-common.h" #include "htp-ctx.h" #include "htp-ops.h" -#include "htp-ops.h" -#include "hmx-ops.h" +#include "matmul-ops.h" +#include "vtcm-utils.h" -#define MM_SPAD_SRC0_NROWS 16 -#define MM_SPAD_SRC1_NROWS 16 -#define MM_SPAD_DST_NROWS 2 +typedef struct { + float *dst; + const float *activation; + const __fp16 *weight; + int m; + int k; + int n; + int act_stride; + int weight_stride; + int dst_stride; + int ne02; + int ne03; + int ne12; + int ne13; + size_t src0_nb2; + size_t src0_nb3; + size_t src1_nb2; + size_t src1_nb3; + size_t dst_nb2; + size_t dst_nb3; +} hmx_mm_f16_f32_batched_params_t; -struct htp_matmul_context { +struct htp_mm_context { const char * type; struct htp_ops_context * octx; - void (*vec_dot_1x1)(const int n, float * restrict s0, + void (*vec_dot_1x1)(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vy0); - void (*vec_dot_2x1)(const int n, float * restrict s0, + void (*vec_dot_2x1)(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0); - void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1, + void (*vec_dot_2x2)(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1); - void (*vec_dot_4x1)(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0); + void (*vec_dot_32x1)(const uint32_t n, float * restrict s, + const void * restrict vx, + const void * restrict vy, uint32_t valid_rows); // Precomputed values uint32_t src0_nrows_per_thread; @@ -53,11 +71,37 @@ struct htp_matmul_context { struct fastdiv_values mm_div_ne1; struct fastdiv_values mm_div_r2; struct fastdiv_values mm_div_r3; + struct fastdiv_values mm_div_ne11; + + // Precomputed block-parallel quantization values + uint32_t quant_ib_first[MAX_NUM_WORKERS]; + uint32_t quant_ib_last[MAX_NUM_WORKERS]; + uint32_t quant_r[MAX_NUM_WORKERS]; + uint32_t quant_c[MAX_NUM_WORKERS]; // Fields for scattered mapping & HMX support in MUL_MAT_ID const uint32_t * matrix_row_counts; const struct mmid_row_mapping * matrix_rows; - bool hmx_eligible; + + // Dynamic VTCM pointers allocated sequentially + uint8_t * vtcm_src0; + uint8_t * vtcm_src1; + uint8_t * vtcm_src2; + uint8_t * vtcm_src3; + uint8_t * vtcm_dst; + + // Cached strides + uint32_t vtcm_src0_stride; + uint32_t vtcm_src1_stride; + uint32_t vtcm_src2_stride; + uint32_t vtcm_src3_stride; + + // Cached thread offsets/sizes + uint32_t vtcm_src0_size_per_thread; + uint32_t vtcm_src1_size_per_thread; + uint32_t vtcm_src2_size_per_thread; + uint32_t vtcm_src3_size_per_thread; + uint32_t vtcm_dst_size_per_thread; }; // vdelta control to expand first 32 e8m0 values into 32 uint32 elements @@ -89,2835 +133,6 @@ static const uint8_t __attribute__((aligned(VLEN))) kvalues_mxfp4_lut[] = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, }; -static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_iq4nlx4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_iq4nl_lut; - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i = 0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i * 2 + 0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); - r.v[i * 2 + 1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); - } - - return r; -} - -// q4x4x2 and q8x4x2 are the flat q4/8_0 formats where all quants are stored first followed by all scales - -static inline size_t q8x4x2_row_size(uint32_t ne) { - // ensures perfect alignment of quants and full row - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (ne + qk - 1) / qk; - return hex_round_up(ne + nb * 8 * sizeof(__fp16), 128); -} - -static inline size_t q8_1x4x2_row_size(uint32_t ne) { - // ensures perfect alignment of quants and full row - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (ne + qk - 1) / qk; - return hex_round_up(ne + nb * 8 * 2 * sizeof(__fp16), 128); -} - -static inline HVX_Vector_x8 hvx_vec_load_q4x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - // Convert uint4 to int4 (i.e. x - 8) - v0 = Q6_Vb_vsub_VbVb(v0, i8); - v1 = Q6_Vb_vsub_VbVb(v1, i8); - v2 = Q6_Vb_vsub_VbVb(v2, i8); - v3 = Q6_Vb_vsub_VbVb(v3, i8); - v4 = Q6_Vb_vsub_VbVb(v4, i8); - v5 = Q6_Vb_vsub_VbVb(v5, i8); - v6 = Q6_Vb_vsub_VbVb(v6, i8); - v7 = Q6_Vb_vsub_VbVb(v7, i8); - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static HVX_Vector_x8 hvx_vec_load_q4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i=0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i*2+0] = Q6_Vb_vsub_VbVb(v0, i8); - r.v[i*2+1] = Q6_Vb_vsub_VbVb(v1, i8); - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i*2+0] = Q6_Vb_vsub_VbVb(Q6_V_lo_W(v0_1_p), i8); - r.v[i*2+1] = Q6_Vb_vsub_VbVb(Q6_V_hi_W(v0_1_p), i8); - } - - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_q4_1x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 : second 128 elements - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F ... - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static HVX_Vector_x8 hvx_vec_load_q4_1x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i=0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i*2+0] = v0; - r.v[i*2+1] = v1; - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i*2+0] = Q6_V_lo_W(v0_1_p); - r.v[i*2+1] = Q6_V_hi_W(v0_1_p); - } - - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0_1 = vptr[0]; // first 256 elements (128 bytes) - HVX_Vector v2_3 = vptr[1]; // ... - HVX_Vector v4_5 = vptr[2]; // ... - HVX_Vector v6_7 = vptr[3]; // ... - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - - HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 - HVX_Vector v2 = Q6_V_vand_VV(v2_3, mask_h4); // & 0x0F - HVX_Vector v3 = Q6_Vub_vlsr_VubR(v2_3, 4); // >> 4 - HVX_Vector v4 = Q6_V_vand_VV(v4_5, mask_h4); // & 0x0F - HVX_Vector v5 = Q6_Vub_vlsr_VubR(v4_5, 4); // >> 4 - HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F - HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8_partial(const uint8_t * restrict ptr, uint32_t n) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - const uint32_t qk = QK_Q4_0x4x2; // 256 - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); - const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - - HVX_Vector_x8 r; - uint32_t i = 0; - - #pragma unroll(2) - for (i=0; i < nb; i++) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : first 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : second 128 elements - r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - } - - if (nloe) { - HVX_Vector v = vptr[i]; // 256 elements (128 bytes) - HVX_Vector v0 = Q6_V_vand_VV(v, mask_h4); // & 0x0F : even 128 elements - HVX_Vector v1 = Q6_Vub_vlsr_VubR(v, 4); // >> 4 : odd 128 elements - HVX_VectorPair v0_1_p = Q6_W_vshuff_VVR(v1, v0, -1); // zip even:odd:... - r.v[i*2+0] = Q6_Vb_vlut32_VbVbI(Q6_V_lo_W(v0_1_p), lut, 0); - r.v[i*2+1] = Q6_Vb_vlut32_VbVbI(Q6_V_hi_W(v0_1_p), lut, 0); - } - - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_full(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0 = vptr[0]; // first 128 vals - HVX_Vector v1 = vptr[1]; // ... - HVX_Vector v2 = vptr[2]; // ... - HVX_Vector v3 = vptr[3]; // ... - HVX_Vector v4 = vptr[4]; // ... - HVX_Vector v5 = vptr[5]; // ... - HVX_Vector v6 = vptr[6]; // ... - HVX_Vector v7 = vptr[7]; // ... - - HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; - return r; -} - -static inline HVX_Vector_x8 hvx_vec_load_q8x4x8_partial(const uint8_t * restrict ptr, uint32_t nloe) { - return hvx_vec_load_q8x4x8_full(ptr); -} - -// Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). -// Accumulate each block into a single int32 value. -// Return a single HVX vector with 32x int32 accumulators. -// This version is parameterized to support less than 1024 elements. -// if() checks are optimized out at compile time -- make sure to pass N as a constexpr. - -static inline HVX_Vector hvx_vec_rmpy_x8_n(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - HVX_Vector r0 = Q6_V_vzero(); - HVX_Vector r1 = Q6_V_vzero(); - HVX_Vector r2 = Q6_V_vzero(); - HVX_Vector r3 = Q6_V_vzero(); - HVX_Vector r4 = Q6_V_vzero(); - HVX_Vector r5 = Q6_V_vzero(); - HVX_Vector r6 = Q6_V_vzero(); - HVX_Vector r7 = Q6_V_vzero(); - - HVX_VectorPair p3; - HVX_VectorPair p2; - HVX_VectorPair p1; - HVX_VectorPair p0; - - if (n >= 128) { r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); } - if (n >= 256) { r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); } - if (n >= 384) { r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); } - if (n >= 512) { r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); } - if (n >= 640) { r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); } - if (n >= 768) { r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); } - if (n >= 896) { r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); } - if (n >= 1024) { r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); } - - if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } - if (n >= 384) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } - if (n >= 640) { p2 = Q6_W_vdeal_VVR(r5, r4, -4); } - if (n >= 896) { p3 = Q6_W_vdeal_VVR(r7, r6, -4); } - - if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } - if (n >= 384) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } - if (n >= 640) { r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); } - if (n >= 896) { r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); } - - if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } - if (n >= 640) { p1 = Q6_W_vdeal_VVR(r3, r2, -4); } - - if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } - if (n >= 640) { r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); } - - if (n >= 128) { p0 = Q6_W_vdeal_VVR(r1, r0, -4); } - if (n >= 128) { r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); } - - return r0; -} - -static inline HVX_Vector hvx_vec_rmpy_x8_full(HVX_Vector_x8 x, HVX_Vector_x8 y) { - HVX_Vector r0 = Q6_Vw_vrmpy_VbVb(x.v[0], y.v[0]); - HVX_Vector r1 = Q6_Vw_vrmpy_VbVb(x.v[1], y.v[1]); - HVX_Vector r2 = Q6_Vw_vrmpy_VbVb(x.v[2], y.v[2]); - HVX_Vector r3 = Q6_Vw_vrmpy_VbVb(x.v[3], y.v[3]); - HVX_Vector r4 = Q6_Vw_vrmpy_VbVb(x.v[4], y.v[4]); - HVX_Vector r5 = Q6_Vw_vrmpy_VbVb(x.v[5], y.v[5]); - HVX_Vector r6 = Q6_Vw_vrmpy_VbVb(x.v[6], y.v[6]); - HVX_Vector r7 = Q6_Vw_vrmpy_VbVb(x.v[7], y.v[7]); - - HVX_VectorPair p0 = Q6_W_vdeal_VVR(r1, r0, -4); - HVX_VectorPair p1 = Q6_W_vdeal_VVR(r3, r2, -4); - HVX_VectorPair p2 = Q6_W_vdeal_VVR(r5, r4, -4); - HVX_VectorPair p3 = Q6_W_vdeal_VVR(r7, r6, -4); - - r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); - r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); - r2 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p2), Q6_V_hi_W(p2)); - r3 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p3), Q6_V_hi_W(p3)); - - p0 = Q6_W_vdeal_VVR(r1, r0, -4); - p1 = Q6_W_vdeal_VVR(r3, r2, -4); - - r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); - r1 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p1), Q6_V_hi_W(p1)); - - p0 = Q6_W_vdeal_VVR(r1, r0, -4); - r0 = Q6_Vw_vadd_VwVw(Q6_V_lo_W(p0), Q6_V_hi_W(p0)); - - return r0; -} - -static inline HVX_Vector hvx_vec_rmpy_x8_partial(HVX_Vector_x8 x, HVX_Vector_x8 y, unsigned int n) { - if (n >= 512) - return hvx_vec_rmpy_x8_full(x, y); - - return hvx_vec_rmpy_x8_partial(x, y, 512); -} - -static void vec_dot_q4_1x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales/offsets - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair dm_deal = Q6_W_vdeal_VVR(dm, dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ms = Q6_V_vand_QV(bmask, r0_ms); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_q4_1x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ms = Q6_V_vand_QV(bmask, r0_ms); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r1_ms = Q6_V_vand_QV(bmask, r1_ms); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_q4_1x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales/sums - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); - HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); - - HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); - HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); - - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); - - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_q4_1x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_q4_1x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector ds = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_VectorPair ds_deal = Q6_W_vdeal_VVR(ds, ds, -2); - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds_deal)); - HVX_Vector vy_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r2_dm = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_VectorPair r2_dm_deal = Q6_W_vdeal_VVR(r2_dm, r2_dm, -2); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r2_dm_deal)); - HVX_Vector r2_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r2_dm_deal)); - - HVX_Vector r3_dm = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - HVX_VectorPair r3_dm_deal = Q6_W_vdeal_VVR(r3_dm, r3_dm, -2); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r3_dm_deal)); - HVX_Vector r3_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r3_dm_deal)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy_s))); - - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy_s))); - - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r2_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_m, vy_s))); - - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - HVX_Vector r3_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_m, vy_s))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ms = Q6_V_vand_QV(bmask, r0_ms); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r1_ms = Q6_V_vand_QV(bmask, r1_ms); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r2_ms = Q6_V_vand_QV(bmask, r2_ms); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r3_ms = Q6_V_vand_QV(bmask, r3_ms); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_ms); - - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_ms); - - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r2_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_ms); - - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - HVX_Vector r3_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_ms); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa_total, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa_total, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa_total, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa_total, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_q4_1x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2 * 2; // 32x (d, m) __fp16 = 128 bytes - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 4; // 32x (d, s) __fp16 = 128 bytes - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales/sums - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales/sums - - // Row sums (sf) - 4 accumulators for 2ร—2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0ร—c0, r0ร—c1, r1ร—c0, r1ร—c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); - HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); - - HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); - HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); - - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); - - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); - - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); - HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); - HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); - HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4_1x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4_1x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector ds0 = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_VectorPair ds0_deal = Q6_W_vdeal_VVR(ds0, ds0, -2); - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds0_deal)); - HVX_Vector vy0_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds0_deal)); - - HVX_Vector ds1 = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_VectorPair ds1_deal = Q6_W_vdeal_VVR(ds1, ds1, -2); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(ds1_deal)); - HVX_Vector vy1_s = Q6_Vh_vshuff_Vh(Q6_V_hi_W(ds1_deal)); - - HVX_Vector r0_dm = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_VectorPair r0_dm_deal = Q6_W_vdeal_VVR(r0_dm, r0_dm, -2); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r0_dm_deal)); - HVX_Vector r0_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r0_dm_deal)); - - HVX_Vector r1_dm = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_VectorPair r1_dm_deal = Q6_W_vdeal_VVR(r1_dm, r1_dm, -2); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(Q6_V_lo_W(r1_dm_deal)); - HVX_Vector r1_m = Q6_Vh_vshuff_Vh(Q6_V_hi_W(r1_dm_deal)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy0_s))); - - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r0_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_m, vy1_s))); - - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c0_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy0_s))); - - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - HVX_Vector r1_c1_ms = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_m, vy1_s))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c0_ms = Q6_V_vand_QV(bmask, r0_c0_ms); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r0_c1_ms = Q6_V_vand_QV(bmask, r0_c1_ms); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c0_ms = Q6_V_vand_QV(bmask, r1_c0_ms); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r1_c1_ms = Q6_V_vand_QV(bmask, r1_c1_ms); - - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - HVX_Vector r0_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_ms); - HVX_Vector r0_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_ms); - HVX_Vector r1_c0_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_ms); - HVX_Vector r1_c1_fa_total = Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_ms); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa_total, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa_total, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa_total, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa_total, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - -static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elemements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_q4x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_q4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_q4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales - - // Row sums (sf) - 4 accumulators for 2ร—2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0ร—c0, r0ร—c1, r1ร—c0, r1ร—c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Zero out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - -static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_q8x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (qf32) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_q8x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_q8x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q8_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales - - // Row sums (sf) - 4 accumulators for 2ร—2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0ร—c0, r0ร—c1, r1ร—c0, r1ร—c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - // Zero out unused elements - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - -// ======== IQ4_NL x Q8_0 vec_dot kernels ======== -// Same structure as Q4_0 vec_dot but uses IQ4_NL LUT-based load (4-bit index -> int8 kvalue). -// Scale format is identical to Q4_0 (fp16 scales). - -static void vec_dot_iq4nlx4x2_q8x4x2_1x1(const int n, - float * restrict s0, - const void * restrict vx0, - const void * restrict vy0) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - HVX_Vector r0_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_iq4nlx4x2_q8x4x2_2x1(const int n, - float * restrict s0, - const void * restrict vx0, - const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_iq4nlx4x2_q8x4x2_4x1(const int n, - float * restrict s0, - const void * restrict vx0, - const void * restrict vx1, - const void * restrict vx2, - const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full(y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial(y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_iq4nlx4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_iq4nlx4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy_q, nloe)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r2_q, vy_q, nloe)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r3_q, vy_q, nloe)); - - HVX_Vector vy_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - HVX_Vector r2_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r2_x_d + i * x_dblk_size)); - HVX_Vector r3_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r3_x_d + i * x_dblk_size)); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy_d))); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy_d))); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r2_d, vy_d))); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r3_d, vy_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_iq4nlx4x2_q8x4x2_2x2(const int n, - float * restrict s0, - float * restrict s1, - const void * restrict vx0, - const void * restrict vx1, - const void * restrict vy0, - const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_Q4_0x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; - - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; - const uint32_t nloe = n % qk; - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial(y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial(y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_iq4nlx4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_iq4nlx4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); - HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); - HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); - HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); - - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); -} - -static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (qf32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy_q, nloe)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - - // Zero-out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - } - - r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - - hvx_vec_store_u(s0, 4, r0_sum); -} - -static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - - // Multiply and accumulate into int32. - // Compute combined scale (fp32). - // Apply scale to acc and accumulate into the row sum (f32). - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - - // Zero-out unused values - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - } - - HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(s0, 8, rsum); -} - -static void vec_dot_mxfp4x4x2_q8x4x2_4x1(const int n, float * restrict s0, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vx2, const void * restrict vx3, - const void * restrict vy0) { - assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vx2 % 128 == 0); - assert((unsigned long) vx3 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r2_x_q = ((const uint8_t *) vx2) + 0; // quants first - const uint8_t * restrict r2_x_d = ((const uint8_t *) vx2) + x_qrow_size; // then scales - const uint8_t * restrict r3_x_q = ((const uint8_t *) vx3) + 0; // quants first - const uint8_t * restrict r3_x_d = ((const uint8_t *) vx3) + x_qrow_size; // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales - - // Row sum (sf) - HVX_Vector r0_sum = Q6_V_vzero(); - HVX_Vector r1_sum = Q6_V_vzero(); - HVX_Vector r2_sum = Q6_V_vzero(); - HVX_Vector r3_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - int32_t nloe = n % qk; // num leftover elemements (must be signed) - - uint32_t i = 0; - for (; i < nb; i++) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_full( y_q + i * y_qblk_size); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); - HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_full(r2_x_q + i * x_qblk_size); - HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_full(r3_x_q + i * x_qblk_size); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - r2_d = Q6_V_vdelta_VV(r2_d, expand); - r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); - r2_d = Q6_Vw_vasl_VwR(r2_d, 23); - r3_d = Q6_V_vdelta_VV(r3_d, expand); - r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); - r3_d = Q6_Vw_vasl_VwR(r3_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - if (nloe) { - HVX_Vector_x8 vy_q = hvx_vec_load_q8x4x8_partial( y_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r2_q = hvx_vec_load_mxfp4x4x8_partial(r2_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r3_q = hvx_vec_load_mxfp4x4x8_partial(r3_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy_q)); - HVX_Vector r1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy_q)); - HVX_Vector r2_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r2_q, vy_q)); - HVX_Vector r3_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r3_q, vy_q)); - - HVX_Vector vy_d = *(const HVX_UVector *) (y_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - HVX_Vector r2_d = *(const HVX_UVector *) (r2_x_d + i * x_dblk_size); - HVX_Vector r3_d = *(const HVX_UVector *) (r3_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy_d), half)); - vy_d = Q6_Vsf_equals_Vqf32(vy_d); - - // Convert rX_d scales from e8m0 to fp32 - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - r2_d = Q6_V_vdelta_VV(r2_d, expand); - r2_d = Q6_V_vand_VV(r2_d, e8m0_mask); - r2_d = Q6_Vw_vasl_VwR(r2_d, 23); - r3_d = Q6_V_vdelta_VV(r3_d, expand); - r3_d = Q6_V_vand_VV(r3_d, e8m0_mask); - r3_d = Q6_Vw_vasl_VwR(r3_d, 23); - - HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); - HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); - HVX_Vector r2_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r2_d, vy_d)); - HVX_Vector r3_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r3_d, vy_d)); - - // Zero-out unused values - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_dd = Q6_V_vand_QV(bmask, r0_dd); - r1_dd = Q6_V_vand_QV(bmask, r1_dd); - r2_dd = Q6_V_vand_QV(bmask, r2_dd); - r3_dd = Q6_V_vand_QV(bmask, r3_dd); - r0_ia = Q6_V_vand_QV(bmask, r0_ia); - r1_ia = Q6_V_vand_QV(bmask, r1_ia); - r2_ia = Q6_V_vand_QV(bmask, r2_ia); - r3_ia = Q6_V_vand_QV(bmask, r3_ia); - - HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); - HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); - HVX_Vector r2_fa = Q6_Vqf32_vmpy_VsfVsf(r2_ia, r2_dd); - HVX_Vector r3_fa = Q6_Vqf32_vmpy_VsfVsf(r3_ia, r3_dd); - - r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum)); - r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum)); - r2_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r2_fa, r2_sum)); - r3_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r3_fa, r3_sum)); - } - - HVX_Vector_x4 rsum_in = { .v = { r0_sum, r1_sum, r2_sum, r3_sum } }; - HVX_Vector rsum = hvx_vec_reduce_sum_f32x4(rsum_in); - hvx_vec_store_u(s0, 16, rsum); -} - - -static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, - const void * restrict vx0, const void * restrict vx1, - const void * restrict vy0, const void * restrict vy1) { - assert(n % 32 == 0); - assert((unsigned long) vx0 % 128 == 0); - assert((unsigned long) vx1 % 128 == 0); - assert((unsigned long) vy0 % 128 == 0); - assert((unsigned long) vy1 % 128 == 0); - - const uint32_t qk = QK_MXFP4x4x2 * 4; - - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) - - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - - const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first - const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales - const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first - const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales - - // Row sums (sf) - 4 accumulators for 2ร—2 tile - HVX_Vector r0_c0_sum = Q6_V_vzero(); - HVX_Vector r0_c1_sum = Q6_V_vzero(); - HVX_Vector r1_c0_sum = Q6_V_vzero(); - HVX_Vector r1_c1_sum = Q6_V_vzero(); - - const uint32_t nb = n / qk; // num full blocks - const uint32_t nloe = n % qk; // num leftover elements - - uint32_t i = 0; - for (; i < nb; i++) { - // Load src1 columns (reused across both src0 rows) - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_full(y0_q + i * y_qblk_size); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_full(y1_q + i * y_qblk_size); - - // Load src0 rows (reused across both src1 columns) - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_full(r0_x_q + i * x_qblk_size); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_full(r1_x_q + i * x_qblk_size); - - // Compute 4 dot products: r0ร—c0, r0ร—c1, r1ร—c0, r1ร—c1 - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); - - // Load scales - HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); - vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); - vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); - vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - // Compute combined scales - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); - - // Apply scales and accumulate - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Process leftovers - if (nloe) { - HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8_partial( y0_q + i * y_qblk_size, nloe); - HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8_partial( y1_q + i * y_qblk_size, nloe); - HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8_partial(r0_x_q + i * x_qblk_size, nloe); - HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8_partial(r1_x_q + i * x_qblk_size, nloe); - - HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy0_q, nloe)); - HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r0_q, vy1_q, nloe)); - HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy0_q, nloe)); - HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_partial(r1_q, vy1_q, nloe)); - - HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); - HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); - HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); - HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); - - // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving - HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 - vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); - vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); - vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); - vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); - - // Convert rX_d scales from e8m0 to fp32 - // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... - // Left shift with zero fill to create FP32 - // FIXME: might need to handle zero as a special case (see ggml-cpu code) - HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; - HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); - r0_d = Q6_V_vdelta_VV(r0_d, expand); - r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); - r0_d = Q6_Vw_vasl_VwR(r0_d, 23); - r1_d = Q6_V_vdelta_VV(r1_d, expand); - r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); - r1_d = Q6_Vw_vasl_VwR(r1_d, 23); - - HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); - HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); - HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); - HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); - - // Zero out unused scales - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); - r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); - r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); - r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); - r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); - r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); - r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); - r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); - r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); - - HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); - HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); - HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); - HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); - - r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); - r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); - r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); - r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); - } - - // Reduce and store results - HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); - HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); - - hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 - hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 -} - #if __HVX_ARCH__ < 79 #define HVX_OP_ADD_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(a, b)) #define HVX_OP_MUL_F32(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) @@ -2926,7 +141,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float #define HVX_OP_MUL_F32(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) #endif -static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f32_f32_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -2954,7 +169,7 @@ static void vec_dot_f32_f32_aa_1x1(const int n, float * restrict s, const void * *s = hvx_vec_get_f32(hvx_vec_reduce_sum_f32(rsum)); } -static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, +static void vec_dot_f32_f32_aa_2x1(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -2996,7 +211,7 @@ static void vec_dot_f32_f32_aa_2x1(const int n, float * restrict s0, s0[1] = va.fp32[1]; } -static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * restrict s1, +static void vec_dot_f32_f32_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -3054,7 +269,7 @@ static void vec_dot_f32_f32_aa_2x2(const int n, float * restrict s0, float * res s1[1] = va1.fp32[1]; } -static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f32_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -3088,7 +303,7 @@ static void vec_dot_f32_f32_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_aa_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -3115,7 +330,7 @@ static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(s, 4, hvx_vec_reduce_sum_f32(rsum)); } -static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, +static void vec_dot_f16_f16_aa_2x1(const uint32_t n, float * restrict s0, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -3152,7 +367,7 @@ static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1, +static void vec_dot_f16_f16_aa_2x2(const uint32_t n, float * restrict s0, float * restrict s1, const void * restrict vx0, const void * restrict vx1, const void * restrict vy0, const void * restrict vy1) { const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; @@ -3212,7 +427,7 @@ static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * res hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 } -static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_uu_1x1(const uint32_t n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_UVector * restrict x = (const HVX_UVector *) vx; const HVX_UVector * restrict y = (const HVX_UVector *) vy; @@ -3242,7 +457,7 @@ static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f16_f32_uu_1x1(const uint32_t n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -3295,65 +510,58 @@ static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * hvx_vec_store_u(&s[0], 4, rsum); } -#define htp_matmul_tensors_preamble \ - const struct htp_tensor * restrict src0 = octx->src[0]; \ - const struct htp_tensor * restrict src1 = octx->src[1]; \ - const struct htp_tensor * restrict src2 = octx->src[2]; \ - const struct htp_tensor * restrict dst = octx->dst; \ - struct htp_spad * restrict src0_spad = &octx->src0_spad; \ - struct htp_spad * restrict src1_spad = &octx->src1_spad; \ - struct htp_spad * restrict dst_spad = &octx->dst_spad; \ - \ - const uint32_t ne00 = src0->ne[0]; \ - const uint32_t ne01 = src0->ne[1]; \ - const uint32_t ne02 = src0->ne[2]; \ - const uint32_t ne03 = src0->ne[3]; \ - \ - const uint32_t ne10 = src1->ne[0]; \ - const uint32_t ne11 = src1->ne[1]; \ - const uint32_t ne12 = src1->ne[2]; \ - const uint32_t ne13 = src1->ne[3]; \ - \ - const uint32_t ne20 = src2->ne[0]; \ - const uint32_t ne21 = src2->ne[1]; \ - const uint32_t ne22 = src2->ne[2]; \ - const uint32_t ne23 = src2->ne[3]; \ - \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ - const uint32_t nb01 = src0->nb[1]; \ - const uint32_t nb02 = src0->nb[2]; \ - const uint32_t nb03 = src0->nb[3]; \ - \ - const uint32_t nb10 = src1->nb[0]; \ - const uint32_t nb11 = src1->nb[1]; \ - const uint32_t nb12 = src1->nb[2]; \ - const uint32_t nb13 = src1->nb[3]; \ - \ - const uint32_t nb0 = dst->nb[0]; \ - const uint32_t nb1 = dst->nb[1]; \ - const uint32_t nb2 = dst->nb[2]; \ +#define htp_matmul_tensors_preamble \ + const struct htp_tensor * restrict src0 = octx->src[0]; \ + const struct htp_tensor * restrict src1 = octx->src[1]; \ + const struct htp_tensor * restrict src2 = octx->src[2]; \ + const struct htp_tensor * restrict dst = octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t ne11 = src1->ne[1]; \ + const uint32_t ne12 = src1->ne[2]; \ + const uint32_t ne13 = src1->ne[3]; \ + \ + const uint32_t ne20 = src2->ne[0]; \ + const uint32_t ne21 = src2->ne[1]; \ + const uint32_t ne22 = src2->ne[2]; \ + const uint32_t ne23 = src2->ne[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t nb10 = src1->nb[0]; \ + const uint32_t nb11 = src1->nb[1]; \ + const uint32_t nb12 = src1->nb[2]; \ + const uint32_t nb13 = src1->nb[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -#define htp_matmul_preamble \ - struct htp_matmul_context * mmctx = data; \ - struct htp_ops_context * octx = mmctx->octx; \ - htp_matmul_tensors_preamble; \ - dma_queue *dma_queue = octx->ctx->dma[ith]; \ - uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; +#define htp_matmul_preamble \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; \ + htp_matmul_tensors_preamble; // *** matmul with support for 4d tensors and full broadcasting -static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mm_4d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); assert(ne12 % ne02 == 0); assert(ne13 % ne03 == 0); @@ -3388,7 +596,9 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { return; } - // block-tiling attempt + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0_start); + const uint32_t blck_0 = 64; const uint32_t blck_1 = 64; @@ -3412,28 +622,606 @@ static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, iir0); for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { const uint8_t * restrict src0_row = src0_base + ir0 * nb01; mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col); } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, iir0); } } } - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0_start); } -// src1 tensor is already in VTCM spad -static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { - htp_matmul_preamble; +#include "hmx-mm-kernels-tiled.h" +#include "hvx-mm-kernels-tiled.h" +#include "hvx-mm-kernels-flat.h" + +// Specialized repacked matmul macros +#define MATMUL_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X2, DOT_2X1) \ +static void hvx_mm_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + htp_matmul_preamble; \ + \ + const uint32_t src0_nrows = ne01 * ne02 * ne03; \ + const uint32_t src1_nrows = ne11 * ne12 * ne13; \ + \ + const uint32_t src0_start_row = src0_nrows_per_thread * ith; \ + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); \ + \ + if (src0_start_row >= src0_end_row) { \ + return; \ + } \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + const size_t dst_row_size = nb1; \ + const size_t src1_row_size = nb11; \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * restrict vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; \ + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * restrict src1_data = mmctx->vtcm_src1; \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + \ + uint32_t ct_start = src0_start_row / 32; \ + uint32_t ct_end = (src0_end_row + 31) / 32; \ + \ + uint32_t push_ct = ct_start; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start; ct < ct_end; ct++) { \ + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)ne0 - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size)); \ + float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size)); \ + \ + float * dst_ptr0 = &dst_row0[ct * 32]; \ + float * dst_ptr1 = &dst_row1[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0, dst_ptr1, w_tile, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); \ + float * dst_ptr = &dst_row[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr, w_tile, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ +} + +#define MATVEC_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X1) \ +static void hvx_mv_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + htp_matmul_preamble; \ + \ + const uint32_t src0_nrows = ne01; \ + \ + const uint32_t src0_start_row = src0_nrows_per_thread * ith; \ + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); \ + \ + if (src0_start_row >= src0_end_row) { \ + return; \ + } \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + const size_t dst_row_size = nb1; \ + const size_t src1_row_size = nb11; \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; \ + uint8_t * vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * src1_data = mmctx->vtcm_src1; \ + \ + float * tmp = (float *) vtcm_dst_ptr; \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + const uint8_t * restrict src1_col = (const uint8_t *) src1_data; \ + float * restrict dst_col = (float *) dst->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + \ + uint32_t ct_start = src0_start_row / 32; \ + uint32_t ct_end = (src0_end_row + 31) / 32; \ + \ + uint32_t push_ct = ct_start; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start; ct < ct_end; ct++) { \ + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; \ + \ + float * dst_ptr = &tmp[ct * 32 - src0_start_row]; \ + int valid_rows = (int)ne0 - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + DOT_2X1(ne10, dst_ptr, w_tile, src1_col, valid_rows); \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ + \ + int copy_cnt = (int)MIN(src0_end_row, ne0) - (int)src0_start_row; \ + if (copy_cnt > 0) { \ + hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, copy_cnt); \ + } \ +} + +#define MATMUL_QKV_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X2, DOT_2X1) \ +static void hvx_mm_qkv_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + \ + const struct htp_tensor * restrict src0 = octx->src[0]; /* Wk */ \ + const struct htp_tensor * restrict src1 = octx->src[1]; /* x */ \ + const struct htp_tensor * restrict src2 = octx->src[2]; /* Wv */ \ + const struct htp_tensor * restrict src3 = octx->src[3]; /* Wq */ \ + const struct htp_tensor * restrict dst_k = octx->dsts[0]; \ + const struct htp_tensor * restrict dst_v = octx->dsts[1]; \ + const struct htp_tensor * restrict dst_q = octx->dsts[2]; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; \ + \ + const size_t dst_k_row_size = dst_k->nb[1]; /* K and V share output width */ \ + const size_t dst_q_row_size = dst_q->nb[1]; /* Q may be wider (GQA) */ \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; \ + uint8_t * restrict vtcm_src3_ptr = mmctx->vtcm_src3 + mmctx->vtcm_src3_size_per_thread * ith; \ + uint8_t * restrict src1_data = mmctx->vtcm_src1; \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; \ + const uint8_t * restrict src3_row = (const uint8_t *) src3->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + \ + dma_queue * dma_queue = octx->ctx->dma[ith]; \ + \ + /* 1. Process K and V together */ \ + const uint32_t src0_nrows_kv = src0->ne[1] * src0->ne[2] * src0->ne[3]; /* src0 is Wk */ \ + uint32_t src0_nrows_per_thread_kv = (src0_nrows_kv + nth - 1) / nth; \ + src0_nrows_per_thread_kv = hex_round_up(src0_nrows_per_thread_kv, 32); \ + \ + const uint32_t start_row_kv = src0_nrows_per_thread_kv * ith; \ + const uint32_t end_row_kv = MIN(start_row_kv + src0_nrows_per_thread_kv, src0_nrows_kv); \ + \ + if (start_row_kv < end_row_kv) { \ + uint32_t ct_start_kv = start_row_kv / 32; \ + uint32_t ct_end_kv = (end_row_kv + 31) / 32; \ + \ + uint32_t push_ct = ct_start_kv; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end_kv; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + d * tile_row_transfer_size_aligned, \ + src2_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start_kv; ct < ct_end_kv; ct++) { \ + const uint8_t * w_tile_k = dma_queue_pop(dma_queue).dst; \ + const uint8_t * w_tile_v = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)src0->ne[1] - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ith); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + \ + float * restrict dst_row0_k = (float *) (dst_k->data + ((ir1+0) * dst_k_row_size)); \ + float * restrict dst_row1_k = (float *) (dst_k->data + ((ir1+1) * dst_k_row_size)); \ + float * dst_ptr0_k = &dst_row0_k[ct * 32]; \ + float * dst_ptr1_k = &dst_row1_k[ct * 32]; \ + \ + float * restrict dst_row0_v = (float *) (dst_v->data + ((ir1+0) * dst_k_row_size)); \ + float * restrict dst_row1_v = (float *) (dst_v->data + ((ir1+1) * dst_k_row_size)); \ + float * dst_ptr0_v = &dst_row0_v[ct * 32]; \ + float * dst_ptr1_v = &dst_row1_v[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0_k, dst_ptr1_k, w_tile_k, src1_col0, src1_col1, valid_rows); \ + DOT_2X2(ne10, dst_ptr0_v, dst_ptr1_v, w_tile_v, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + \ + float * restrict dst_row_k = (float *) (dst_k->data + (ir1 * dst_k_row_size)); \ + float * dst_ptr_k = &dst_row_k[ct * 32]; \ + \ + float * restrict dst_row_v = (float *) (dst_v->data + (ir1 * dst_k_row_size)); \ + float * dst_ptr_v = &dst_row_v[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr_k, w_tile_k, src1_col, valid_rows); \ + DOT_2X1(ne10, dst_ptr_v, w_tile_v, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ith); \ + \ + if (push_ct < ct_end_kv) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_k, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_v, src2_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ + } \ + \ + /* 2. Process Q separately */ \ + const uint32_t src0_nrows_q = src3->ne[1] * src3->ne[2] * src3->ne[3]; /* src3 is Wq */ \ + uint32_t src0_nrows_per_thread_q = (src0_nrows_q + nth - 1) / nth; \ + src0_nrows_per_thread_q = hex_round_up(src0_nrows_per_thread_q, 32); \ + \ + const uint32_t start_row_q = src0_nrows_per_thread_q * ith; \ + const uint32_t end_row_q = MIN(start_row_q + src0_nrows_per_thread_q, src0_nrows_q); \ + \ + if (start_row_q < end_row_q) { \ + uint32_t ct_start_q = start_row_q / 32; \ + uint32_t ct_end_q = (end_row_q + 31) / 32; \ + \ + uint32_t push_ct = ct_start_q; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end_q; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + d * tile_row_transfer_size_aligned, \ + src3_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start_q; ct < ct_end_q; ct++) { \ + const uint8_t * w_tile_q = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)src3->ne[1] - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + \ + float * restrict dst_row0_q = (float *) (dst_q->data + ((ir1+0) * dst_q_row_size)); \ + float * restrict dst_row1_q = (float *) (dst_q->data + ((ir1+1) * dst_q_row_size)); \ + float * dst_ptr0_q = &dst_row0_q[ct * 32]; \ + float * dst_ptr1_q = &dst_row1_q[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0_q, dst_ptr1_q, w_tile_q, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + \ + float * restrict dst_row_q = (float *) (dst_q->data + (ir1 * dst_q_row_size)); \ + float * dst_ptr_q = &dst_row_q[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr_q, w_tile_q, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end_q) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_q, src3_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ + } \ +} + +#define MATMUL_FFN_2D_REPACKED_IMPL(SUFFIX, TILE_SIZE, DOT_2X2, DOT_2X1) \ +static void hvx_mm_ffn_2d_repacked_##SUFFIX(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + \ + const struct htp_tensor * restrict src0 = octx->src[0]; /* Wgate */ \ + const struct htp_tensor * restrict src1 = octx->src[1]; /* y */ \ + const struct htp_tensor * restrict src2 = octx->src[2]; /* Wup */ \ + const struct htp_tensor * restrict dst_gate = octx->dsts[0]; \ + const struct htp_tensor * restrict dst_up = octx->dsts[1]; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne10 = src1->ne[0]; \ + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; \ + \ + const size_t dst_row_size = dst_gate->nb[1]; \ + const size_t src1_stride = mmctx->vtcm_src1_stride; \ + \ + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; \ + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; \ + uint8_t * restrict src1_data = mmctx->vtcm_src1; \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + \ + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; \ + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; \ + \ + const uint32_t tile_size = TILE_SIZE; \ + const uint32_t aligned_tile_size = hex_align_up(tile_size, 128); \ + \ + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; \ + const uint32_t n_prefetch = kparams->n_prefetch; \ + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); \ + \ + uint32_t n_k_tiles_w = ne00 / 32; \ + uint32_t n_k_tiles_a = ne10 / 32; \ + uint32_t tile_row_stride = n_k_tiles_w * tile_size; \ + uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; \ + dma_queue * dma_queue = octx->ctx->dma[ith]; \ + \ + const uint32_t src0_nrows = ne01 * src0->ne[2] * src0->ne[3]; \ + const uint32_t src0_start_row = mmctx->src0_nrows_per_thread * ith; \ + const uint32_t src0_end_row = MIN(src0_start_row + mmctx->src0_nrows_per_thread, src0_nrows); \ + \ + uint32_t ct_start = src0_start_row / 32; \ + uint32_t ct_end = (src0_end_row + 31) / 32; \ + \ + uint32_t push_ct = ct_start; \ + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, \ + src0_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + d * tile_row_transfer_size_aligned, \ + src2_row + push_ct * tile_row_stride), aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + } \ + \ + for (uint32_t ct = ct_start; ct < ct_end; ct++) { \ + const uint8_t * w_tile_gate = dma_queue_pop(dma_queue).dst; \ + const uint8_t * w_tile_up = dma_queue_pop(dma_queue).dst; \ + \ + int valid_rows = (int)ne01 - (int)(ct * 32); \ + valid_rows = MIN(32, MAX(0, valid_rows)); \ + \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + uint32_t ir1 = 0; \ + for (; ir1 + 1 < src1_nrows; ir1 += 2) { \ + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); \ + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); \ + \ + float * restrict dst_row0_gate = (float *) (dst_gate->data + ((ir1+0) * dst_row_size)); \ + float * restrict dst_row1_gate = (float *) (dst_gate->data + ((ir1+1) * dst_row_size)); \ + float * dst_ptr0_gate = &dst_row0_gate[ct * 32]; \ + float * dst_ptr1_gate = &dst_row1_gate[ct * 32]; \ + \ + float * restrict dst_row0_up = (float *) (dst_up->data + ((ir1+0) * dst_row_size)); \ + float * restrict dst_row1_up = (float *) (dst_up->data + ((ir1+1) * dst_row_size)); \ + float * dst_ptr0_up = &dst_row0_up[ct * 32]; \ + float * dst_ptr1_up = &dst_row1_up[ct * 32]; \ + \ + DOT_2X2(ne10, dst_ptr0_gate, dst_ptr1_gate, w_tile_gate, src1_col0, src1_col1, valid_rows); \ + DOT_2X2(ne10, dst_ptr0_up, dst_ptr1_up, w_tile_up, src1_col0, src1_col1, valid_rows); \ + } \ + \ + for (; ir1 < src1_nrows; ++ir1) { \ + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); \ + \ + float * restrict dst_row_gate = (float *) (dst_gate->data + (ir1 * dst_row_size)); \ + float * dst_ptr_gate = &dst_row_gate[ct * 32]; \ + \ + float * restrict dst_row_up = (float *) (dst_up->data + (ir1 * dst_row_size)); \ + float * dst_ptr_up = &dst_row_up[ct * 32]; \ + \ + DOT_2X1(ne10, dst_ptr_gate, w_tile_gate, src1_col, valid_rows); \ + DOT_2X1(ne10, dst_ptr_up, w_tile_up, src1_col, valid_rows); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); \ + \ + if (push_ct < ct_end) { \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_gate, src0_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile_up, src2_row + push_ct * tile_row_stride), \ + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); \ + push_ct++; \ + } \ + } \ +} + +MATMUL_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x2, tiled_vec_dot_q4_0_32x1) +MATMUL_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x2, tiled_vec_dot_q4_1_32x1) +MATMUL_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x2, tiled_vec_dot_q8_0_32x1) +MATMUL_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x2, tiled_vec_dot_iq4nl_32x1) +MATMUL_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x2, tiled_vec_dot_mxfp4_32x1) + +MATMUL_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x2, flat_vec_dot_q4_0_32x1) +MATMUL_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x2, flat_vec_dot_q4_1_32x1) +MATMUL_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x2, flat_vec_dot_q8_0_32x1) +MATMUL_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x2, flat_vec_dot_iq4nl_32x1) +MATMUL_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x2, flat_vec_dot_mxfp4_32x1) + +#define QUANTIZE_IMPL(name, log_name, kernel_fn, dst_row_size_expr) \ +static void name(unsigned int nth, unsigned int ith, void * data) { \ + struct htp_mm_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + const struct htp_tensor * src = octx->src[1]; \ + const uint32_t ne0 = src->ne[0]; \ + const uint32_t ne1 = src->ne[1]; \ + const uint32_t ne2 = src->ne[2]; \ + const uint32_t ne3 = src->ne[3]; \ + const uint32_t nrows = ne1 * ne2 * ne3; \ + const uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; \ + \ + const uint32_t ir_first = nrows_per_thread * ith; \ + if (ir_first >= nrows) { \ + return; \ + } \ + \ + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); \ + \ + uint8_t * restrict dst = mmctx->vtcm_src1; \ + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); \ + const size_t src_row_size = src->nb[1]; \ + const size_t dst_row_size = (dst_row_size_expr); \ + const uint8_t * restrict src_data = (const uint8_t *) src->data + (src_row_size * ir_first); \ + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); \ + uint8_t * restrict tmp_data = (uint8_t *) mmctx->vtcm_src0 + (mmctx->vtcm_src0_size_per_thread * ith); \ + kernel_fn(src_data, dst_data, tmp_data, ne0, ir_last - ir_first, src_row_size, dst_row_size); \ + \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); \ +} + +QUANTIZE_IMPL(quantize_f32_q8_0_tiled, "quantize-f32-q8_0_tiled", quantize_f32_q8_0_tiled_kernel, htp_mm_q8_0_tiled_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_q8_1_tiled, "quantize-f32-q8_1_tiled", quantize_f32_q8_1_tiled_kernel, htp_mm_q8_1_tiled_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_q8_0_flat, "quantize-f32-q8_0_flat", quantize_f32_q8_0_flat_kernel, htp_mm_q8_0_flat_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_q8_1_flat, "quantize-f32-q8_1_flat", quantize_f32_q8_1_flat_kernel, htp_mm_q8_1_flat_row_size(ne0)) +QUANTIZE_IMPL(quantize_f32_f32_flat, "quantize-f32-f32", quantize_f32_f32_flat_kernel, mmctx->vtcm_src1_stride) +QUANTIZE_IMPL(quantize_f32_f16_flat, "quantize-f32-f16", quantize_f32_f16_flat_kernel, mmctx->vtcm_src1_stride) +QUANTIZE_IMPL(quantize_f16_f16_flat, "quantize-f16-f16", quantize_f16_f16_flat_kernel, mmctx->vtcm_src1_stride) + +static void quantize_f32_q8_0_tiled_block(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); + + const struct htp_tensor * src = octx->src[1]; + + quantize_f32_q8_0_tiled_block_kernel( + (const float *) src->data, + mmctx->vtcm_src1, + (uint8_t *) mmctx->vtcm_src0 + (mmctx->vtcm_src0_size_per_thread * ith), + src->ne[0], + mmctx->quant_ib_first[ith], + mmctx->quant_ib_last[ith], + src->nb[1], + htp_mm_q8_0_tiled_row_size(src->ne[0]), + mmctx->quant_r[ith], + mmctx->quant_c[ith] + ); + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); +} + +static void quantize_f32_q8_1_tiled_block(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); + + const struct htp_tensor * src = octx->src[1]; + + quantize_f32_q8_1_tiled_block_kernel( + (const float *) src->data, + mmctx->vtcm_src1, + (uint8_t *) mmctx->vtcm_src0 + (mmctx->vtcm_src0_size_per_thread * ith), + src->ne[0], + mmctx->quant_ib_first[ith], + mmctx->quant_ib_last[ith], + src->nb[1], + htp_mm_q8_1_tiled_row_size(src->ne[0]), + mmctx->quant_r[ith], + mmctx->quant_c[ith] + ); + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, mmctx->quant_ib_first[ith]); +} + +MATVEC_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x1) +MATVEC_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x1) +MATVEC_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x1) +MATVEC_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x1) +MATVEC_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x1) + +MATVEC_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x1) +MATVEC_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x1) +MATVEC_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x1) +MATVEC_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x1) +MATVEC_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x1) + + +MATMUL_QKV_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x2, tiled_vec_dot_q4_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x2, tiled_vec_dot_q4_1_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x2, tiled_vec_dot_q8_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x2, tiled_vec_dot_iq4nl_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x2, tiled_vec_dot_mxfp4_32x1) + +MATMUL_QKV_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x2, flat_vec_dot_q4_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x2, flat_vec_dot_q4_1_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x2, flat_vec_dot_q8_0_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x2, flat_vec_dot_iq4nl_32x1) +MATMUL_QKV_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x2, flat_vec_dot_mxfp4_32x1) + + +MATMUL_FFN_2D_REPACKED_IMPL(q4_0, 576, tiled_vec_dot_q4_0_32x2, tiled_vec_dot_q4_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q4_1, 640, tiled_vec_dot_q4_1_32x2, tiled_vec_dot_q4_1_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q8_0, 1088, tiled_vec_dot_q8_0_32x2, tiled_vec_dot_q8_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(iq4nl, 576, tiled_vec_dot_iq4nl_32x2, tiled_vec_dot_iq4nl_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(mxfp4, 544, tiled_vec_dot_mxfp4_32x2, tiled_vec_dot_mxfp4_32x1) + +MATMUL_FFN_2D_REPACKED_IMPL(q4_0_flat, 576, flat_vec_dot_q4_0_32x2, flat_vec_dot_q4_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q4_1_flat, 640, flat_vec_dot_q4_1_32x2, flat_vec_dot_q4_1_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(q8_0_flat, 1088, flat_vec_dot_q8_0_32x2, flat_vec_dot_q8_0_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(iq4nl_flat, 576, flat_vec_dot_iq4nl_32x2, flat_vec_dot_iq4nl_32x1) +MATMUL_FFN_2D_REPACKED_IMPL(mxfp4_flat, 544, flat_vec_dot_mxfp4_32x2, flat_vec_dot_mxfp4_32x1) + +static void hvx_mm_2d(unsigned int nth, unsigned int ith, void * data) { + htp_matmul_preamble; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows @@ -3447,34 +1235,31 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { return; } + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; const size_t src1_row_size = nb11; - const size_t src0_stride = src0_spad->stride; - const size_t src1_stride = src1_spad->stride; + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * restrict src1_data = src1_spad->data; - - volatile uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + // Per-thread VTCMs for all tensors + uint8_t * restrict vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; const uint8_t * restrict src0_row = (const uint8_t *) src0->data; - // Prefill spad with src0 rows + // Prefill vtcm with src0 rows #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { + if (is0 >= (int)n_prefetch) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); } // Process src0 rows @@ -3482,7 +1267,6 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - // Process src1 columns in pairs (2ร—2 tiling) uint32_t ir1 = 0; for (; ir1 + 1 < src1_nrows; ir1 += 2) { @@ -3499,24 +1283,23 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col); } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; + // Prefetch next (n + vtcm_nrows) row + const int pr0 = (ir0 + n_prefetch); + const int is0 = (pr0 - src0_start_row) & prefetch_mask; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), - src0_stride, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); } } // Process the last row (if any) if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; - const int is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); + const int is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); @@ -3528,19 +1311,10 @@ static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { } htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// q8x4x2 src1 tensor is already in VTCM spad -static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mv_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; const uint32_t src0_nrows = ne01; @@ -3552,164 +1326,101 @@ static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { return; } + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; const size_t src1_row_size = nb11; - const size_t src0_stride = src0_spad->stride; - const size_t src1_stride = src1_spad->stride; + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * src1_data = src1_spad->data; + // Per-thread VTCMs for all tensors + uint8_t * vtcm_dst_ptr = mmctx->vtcm_dst + mmctx->vtcm_dst_size_per_thread * ith; + uint8_t * vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * src1_data = mmctx->vtcm_src1; - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - float * tmp = (float *) spad_dst; + float * tmp = (float *) vtcm_dst_ptr; const uint8_t * restrict src0_row = (const uint8_t *) src0->data; const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_col = (float *) dst->data; - if (mmctx->vec_dot_4x1 != NULL) { - const uint32_t src0_end_row_x4 = src0_start_row + ((src0_end_row - src0_start_row) & ~3U); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - // Prefill spad with 4x src0 rows - #pragma unroll(4) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 4); + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; + + // Prefill vtcm with 2x src0 rows + #pragma unroll(2) + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint32_t is0 = (ir0 - src0_start_row); + if (is0 >= n_prefetch) { + break; } + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + } - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x4; ir0 += 4) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_4x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, ss0 + 2 * src0_stride, ss0 + 3 * src0_stride, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + // Process src0 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x4) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), - src0_stride, src0_row_size, 4); - } - } - - // Process leftovers - uint32_t ir0 = src0_end_row_x4; - if (ir0 + 2 <= src0_end_row) { - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - ir0 += 2; - } - if (ir0 < src0_end_row) { - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - ir0 += 1; - } - } else { - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); - - // Prefill spad with 2x src0 rows - #pragma unroll(2) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 2); - } - - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - - // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), - src0_stride, src0_row_size, 2); - } - } - - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - const uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), - src0_stride, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + // Prefetch next (n + vtcm_nrows) row + const uint32_t pr0 = (ir0 + n_prefetch); + const uint32_t is0 = (pr0 - src0_start_row) & prefetch_mask; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); } } + // Process the last row (if any) + if (src0_end_row != src0_end_row_x2) { + const uint32_t ir0 = src0_end_row_x2; + const uint32_t is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + } + hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id) * ids->ne[0] * ids->ne[1] + (i1)] -struct mmid_row_mapping { - uint32_t i1; - uint32_t i2; -}; - -// src1 tensor is already in VTCM spad -static void matmul_id(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mm_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; const struct htp_tensor * restrict ids = octx->src[2]; - struct htp_spad * restrict src2_spad = &octx->src2_spad; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint32_t src0_nrows = ne01; // src0 rows per expert - const uint32_t src1_nrows = ne11; - + const uint32_t src0_nrows = ne01; // src0 rows per expert + const uint32_t src1_nrows = ne11; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { return; } + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t n_ids = ids->ne[0]; // n_expert_used const uint32_t n_as = ne02; // n_expert @@ -3717,807 +1428,195 @@ static void matmul_id(unsigned int nth, unsigned int ith, void * data) { const struct mmid_row_mapping * matrix_rows = mmctx->matrix_rows; const size_t dst_row_size = nb1; - const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = htp_mm_q8_0_tiled_row_size(ne10); - const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + const size_t src1_stride = mmctx->vtcm_src1_stride; - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * restrict src1_data = src1_spad->data; + // Per-thread VTCMs for all tensors + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { const int32_t cne1 = matrix_row_counts[cur_a]; - if (cne1 == 0) { continue; } - if (mmctx->hmx_eligible) { - continue; + const uint8_t * src0_row = (const uint8_t *) src0->data + cur_a * nb02; + + const uint32_t tile_size = htp_mm_get_weight_tile_size(src0->type); + const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(src0->type); + const uint32_t n_k_tiles_w = ne00 / 32; + const uint32_t n_k_tiles_a = ne10 / 32; + const uint32_t tile_row_stride = n_k_tiles_w * tile_size; + const uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; + + const uint32_t ct_start = src0_start_row / 32; + const uint32_t ct_end = (src0_end_row + 31) / 32; + + uint32_t push_ct = ct_start; + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); } - const uint8_t * src0_row = (const uint8_t *) src0->data + (0 + cur_a * nb02 + 0); + for (uint32_t ct = ct_start; ct < ct_end; ct++) { + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; - // Prefill spad with src0 rows - #pragma unroll(4) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const int is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); - } + int valid_rows = (int)ne01 - (int)(ct * 32); + valid_rows = MIN(32, MAX(0, valid_rows)); - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); for (uint32_t cid = 0; cid < cne1; ++cid) { struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); const int rm1 = row_mapping.i1; // expert idx const int rm2 = row_mapping.i2; // token idx - const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); - float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); + const uint32_t ir1 = fastmodulo(rm1, ne11, &mmctx->mm_div_ne11); // src1 row idx + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_stride); + float * restrict dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); + mmctx->vec_dot_32x1(ne10, &dst_row[ct * 32], w_tile, src1_col, valid_rows); } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); - // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + if (push_ct < ct_end) { + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); + push_ct++; } } - - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - for (uint32_t cid = 0; cid < cne1; ++cid) { - struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, cid); - const int rm1 = row_mapping.i1; // expert idx - const int rm2 = row_mapping.i2; // token idx - - const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); - float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - - mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); - } - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - } } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], - dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// src1 tensor is already in VTCM spad -static void matvec_id(unsigned int nth, unsigned int ith, void * data) { +static void hvx_mv_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; const struct htp_tensor * restrict ids = octx->src[2]; - struct htp_spad * restrict src2_spad = &octx->src2_spad; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - const uint32_t src0_nrows = ne01; // src0 rows per expert + const uint32_t src0_nrows = ne01; // src0 rows per expert const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); // no work for this thread if (src0_start_row >= src0_end_row) { return; } + struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + assert(ne13 % ne03 == 0); const size_t dst_row_size = nb1; - const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); - - const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + const size_t src1_row_size = htp_mm_q8_0_tiled_row_size(ne10); const uint32_t n_aids = src2->ne[0]; // num activated experts const uint32_t n_ids = ne02; // num experts - // Per-thread VTCM scratchpads for all tensors - // Note that the entire src1 tensor is already in VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; - uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; - uint8_t * restrict src1_data = src1_spad->data; + // Per-thread VTCMs for all tensors + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; for (uint32_t ie1 = 0; ie1 < n_aids; ++ie1) { // for each expert - const uint32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]); - assert(eid < n_ids); + const int32_t eid = *(const int32_t *) ((const uint8_t *) src2->data + ie1 * src2->nb[0]); + if (eid < 0) { + continue; + } + assert(eid < (int32_t) n_ids); const uint8_t * restrict src0_row = (const uint8_t *) src0->data + eid * nb02; const uint8_t * restrict src1_col = (const uint8_t *) src1_data; float * restrict dst_row = (float *) (dst->data + ie1 * nb1); - // Prefill spad with src0 rows - #pragma unroll(4) - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const int is0 = (ir0 - src0_start_row); - if (is0 >= MM_SPAD_SRC0_NROWS) { - break; - } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + const uint32_t tile_size = htp_mm_get_weight_tile_size(src0->type); + const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(src0->type); + const uint32_t n_k_tiles_w = ne00 / 32; + const uint32_t n_k_tiles_a = ne10 / 32; + const uint32_t tile_row_stride = n_k_tiles_w * tile_size; + const uint32_t tile_row_transfer_size_aligned = n_k_tiles_a * aligned_tile_size; + + const uint32_t ct_start = src0_start_row / 32; + const uint32_t ct_end = (src0_end_row + 31) / 32; + + uint32_t push_ct = ct_start; + for (uint32_t d = 0; d < n_prefetch && push_ct < ct_end; d++, push_ct++) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + d * tile_row_transfer_size_aligned, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); } - // Process src0 rows - for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); + for (uint32_t ct = ct_start; ct < ct_end; ct++) { + const uint8_t * w_tile = dma_queue_pop(dma_queue).dst; - // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + int valid_rows = (int)ne01 - (int)(ct * 32); + valid_rows = MIN(32, MAX(0, valid_rows)); + + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ct); + mmctx->vec_dot_32x1(ne10, &dst_row[ct * 32], w_tile, src1_col, valid_rows); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ct); + + if (push_ct < ct_end) { + dma_queue_push(dma_queue, dma_make_ptr((uint8_t *)w_tile, src0_row + push_ct * tile_row_stride), + aligned_tile_size, tile_size, tile_size, n_k_tiles_a); + push_ct++; } } - - // Process the last row (if any) - if (src0_end_row != src0_end_row_x2) { - uint32_t ir0 = src0_end_row_x2; - const uint32_t is0 = (ir0 - src0_start_row) % MM_SPAD_SRC0_NROWS; - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); - const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_COMP, ir0); - } } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, - ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], - dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// *** dynamic quant - -static inline void quantize_block_f32_q8_1x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - HVX_Vector zero = Q6_V_vzero(); - - // Use reduce max fp32 to find max(abs(e)) first - HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); - HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); - HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); - HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); - - // Load and convert into QF32 - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert to QF32 - HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); - HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); - HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); - HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); - - // Combine and convert to fp16 - HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); - HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); - - // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; - - // --- Sum calculation --- - const HVX_Vector ones = Q6_Vb_vsplat_R(1); - HVX_Vector v_sums = Q6_Vw_vrmpy_VbVb(vx_i8, ones); // sum every 4 consecutive elements - // Sum 8 elements: - v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 4)); - v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 8)); - v_sums = Q6_Vw_vadd_VwVw(v_sums, Q6_V_vror_VR(v_sums, 16)); - - // Copy to stack to extract sums and vmaxes - float vmax0[32] __attribute__((aligned(128))); - float vmax1[32] __attribute__((aligned(128))); - float vmax2[32] __attribute__((aligned(128))); - float vmax3[32] __attribute__((aligned(128))); - int32_t sums[32] __attribute__((aligned(128))); - - hvx_vec_store_u(vmax0, 128, vmax0_sf); - hvx_vec_store_u(vmax1, 128, vmax1_sf); - hvx_vec_store_u(vmax2, 128, vmax2_sf); - hvx_vec_store_u(vmax3, 128, vmax3_sf); - hvx_vec_store_u(sums, 128, v_sums); - - float d0 = vmax0[0] / 127.0f; - float d1 = vmax1[0] / 127.0f; - float d2 = vmax2[0] / 127.0f; - float d3 = vmax3[0] / 127.0f; - - __fp16 * y_d_half = (__fp16 *) y_d; - y_d_half[0] = d0; - y_d_half[1] = (float) sums[0] * d0; - y_d_half[2] = d1; - y_d_half[3] = (float) sums[8] * d1; - y_d_half[4] = d2; - y_d_half[5] = (float) sums[16] * d2; - y_d_half[6] = d3; - y_d_half[7] = (float) sums[24] * d3; -} - -static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - HVX_Vector zero = Q6_V_vzero(); - - // Use reduce max fp32 to find max(abs(e)) first - HVX_Vector vmax0_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[0])); - HVX_Vector vmax1_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[1])); - HVX_Vector vmax2_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[2])); - HVX_Vector vmax3_sf = hvx_vec_reduce_max_f32(hvx_vec_abs_f32(vx[3])); - // Load and convert into QF32 - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert to QF32 - HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes - HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes - HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes - HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes - - // Combine and convert to fp16 - HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); - HVX_Vector vmax23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax3_qf, vmax2_qf))); - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); - - hvx_vec_store_u(y_d + 0, 2, vd01_hf); - HVX_Vector rotated_vd_hf = Q6_V_vror_VR(vd01_hf, 64); - hvx_vec_store_u(y_d + 2, 2, rotated_vd_hf); - - hvx_vec_store_u(y_d + 4, 2, vd23_hf); - rotated_vd_hf = Q6_V_vror_VR(vd23_hf, 64); - hvx_vec_store_u(y_d + 6, 2, rotated_vd_hf); - - // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; -} - -static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - - // Load and convert into QF32 - HVX_Vector zero = Q6_V_vzero(); - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - // Compute max and scale - HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes - HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes - - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); - HVX_Vector vd23_hf = Q6_Vhf_equals_Vqf16(vd23_qf16); - - hvx_vec_store_u(y_d + 0, 4, vd01_hf); - hvx_vec_store_u(y_d + 4, 4, vd23_hf); - - // Divide input by the scale - HVX_Vector vd01_inv_hf = hvx_vec_inverse_f16(vd01_hf); - HVX_Vector vd23_inv_hf = hvx_vec_inverse_f16(vd23_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd01_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd23_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; -} - -static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { - assert((unsigned long) x % 128 == 0); - assert((unsigned long) y_q % 128 == 0); - - HVX_Vector * vx = (HVX_Vector *) x; - - // Load and convert into QF32 - HVX_Vector zero = Q6_V_vzero(); - HVX_Vector vx0_qf = Q6_Vqf32_vsub_VsfVsf(vx[0], zero); // 32 elements - HVX_Vector vx1_qf = Q6_Vqf32_vsub_VsfVsf(vx[1], zero); // 32 elements - HVX_Vector vx2_qf = Q6_Vqf32_vsub_VsfVsf(vx[2], zero); // 32 elements - HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements - - // Convert into fp16 - HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); - HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - - // Compute max and scale - HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); - vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes - - HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 - HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); - - *(HVX_UVector *) y_d = vd_hf; - - // Divide input by the scale - HVX_Vector vd_inv_hf = hvx_vec_inverse_f16(vd_hf); - vx01_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx01_hf, vd_inv_hf)); - vx23_hf = Q6_Vhf_equals_Vqf16(Q6_Vqf16_vmpy_VhfVhf(vx23_hf, vd_inv_hf)); - - // Convert to int8 - HVX_Vector vx01_i16 = hvx_vec_i16_from_hf_rnd_sat(vx01_hf); - HVX_Vector vx23_i16 = hvx_vec_i16_from_hf_rnd_sat(vx23_hf); - HVX_Vector vx_i8 = Q6_Vb_vpack_VhVh_sat(vx23_i16, vx01_i16); - - *(HVX_Vector *) y_q = vx_i8; -} - -// Overrides input x -static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { - assert(k % 32 == 0); - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (k + qk - 1) / qk; - - const uint32_t qrow_size = k; // int8 - - const uint32_t dblk_size = 8 * 2; // 8x __fp16 - const uint32_t qblk_size = QK_Q8_0x4x2; // int8 - - uint8_t * restrict y_q = (y + 0); // quants first - uint8_t * restrict y_d = (y + qrow_size); // then scales - - // Temp scales override input since we're working off of the aligned temp buffer in VTCM - uint8_t * restrict t_d = (uint8_t *) x; - - for (uint32_t i = 0; i < nb; i++) { -#if FP32_QUANTIZE_GROUP_SIZE == 32 - quantize_block_f32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#elif FP32_QUANTIZE_GROUP_SIZE == 64 - quantize_block_f32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#elif FP32_QUANTIZE_GROUP_SIZE == 128 - quantize_block_f32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); -#else -#error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" -#endif - } - - // now copy the scales into final location - hvx_copy_f16_ua(y_d, t_d, nb * 8); -} - -static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - struct htp_spad * spad = &octx->src0_spad; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = src->nb[1]; - const size_t dst_row_size = q8x4x2_row_size(ne0); - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); - uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); - - const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); - memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_row_size, 2); - hvx_copy_f32_aa(tmp_data, src_data, ne0); - - // FARF(HIGH, "quantize-q8x4-row: %u\n", i); - quantize_row_f32_q8x4x2((float *) tmp_data, dst_data, ne0); - dst_data += dst_row_size; - src_data += src_row_size; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f32-q8x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - -static void quantize_row_f32_q8_1x4x2(float * restrict x, uint8_t * restrict y, uint32_t k) { - assert(k % 32 == 0); - const uint32_t qk = QK_Q8_0x4x2; - const uint32_t nb = (k + qk - 1) / qk; - - const uint32_t qrow_size = k; // int8 - - const uint32_t dblk_size = 8 * 4; // 8x (d, s) __fp16 = 32 bytes - const uint32_t qblk_size = QK_Q8_0x4x2; // int8 - - uint8_t * restrict y_q = (y + 0); // quants first - uint8_t * restrict y_d = (y + qrow_size); // then scales/sums - - // Temp scales override input since we're working off of the aligned temp buffer in VTCM - uint8_t * restrict t_d = (uint8_t *) x; - - for (uint32_t i = 0; i < nb; i++) { - quantize_block_f32_q8_1x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); - quantize_block_f32_q8_1x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); - } - - // now copy the scales/sums into final location - hvx_copy_f16_ua(y_d, t_d, nb * 16); -} - -static void quantize_f32_q8_1x4x2(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - struct htp_spad * spad = &octx->src0_spad; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = src->nb[1]; - const size_t dst_row_size = q8_1x4x2_row_size(ne0); - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_row_size * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_row_size * ir_first); - uint8_t * restrict tmp_data = (uint8_t *) spad->data + (spad->size_per_thread * ith); - - const size_t src_row_size_padded = hex_round_up(src_row_size, QK_Q8_0x4x2 * sizeof(float)); - memset(tmp_data, 0, src_row_size_padded); // zero-out temp row data for padding - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_row_size, 2); - hvx_copy_f32_aa(tmp_data, src_data, ne0); - - quantize_row_f32_q8_1x4x2((float *) tmp_data, dst_data, ne0); - dst_data += dst_row_size; - src_data += src_row_size; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f32-q8_1x4: %u/%u : n-rows %u (%u:%u) row-size %u -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - -static void quantize_f32_f32(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - uint32_t dst_stride = octx->src1_spad.stride; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = ne0 * sizeof(float); - const size_t src_stride = src->nb[1]; - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_stride, 2); - hvx_copy_f32_au(dst_data, src_data, ne0); - - dst_data += dst_stride; - src_data += src_stride; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f32-f32: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - -static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - uint32_t dst_stride = octx->src1_spad.stride; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = ne0 * sizeof(float); - const size_t src_stride = src->nb[1]; - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_stride, 2); - hvx_copy_f16_f32_au(dst_data, src_data, ne0); - - dst_data += dst_stride; - src_data += src_stride; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f32-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - -// TODO just a plain copy that should be done via the DMA during the Op setup -static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { - struct htp_matmul_context * mmctx = data; - struct htp_ops_context * octx = mmctx->octx; - struct htp_thread_trace * tr = octx->ctx ? &octx->ctx->trace[ith] : NULL; - - const struct htp_tensor * src = octx->src[1]; - uint8_t * restrict dst = octx->src1_spad.data; - uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; - uint32_t dst_stride = octx->src1_spad.stride; - - uint64_t t1 = HAP_perf_get_qtimer_count(); - - const uint32_t ne0 = src->ne[0]; - const uint32_t ne1 = src->ne[1]; - const uint32_t ne2 = src->ne[2]; - const uint32_t ne3 = src->ne[3]; - - const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows - - const uint32_t ir_first = nrows_per_thread * ith; // first row - htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); - const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row - - const size_t src_row_size = ne0 * sizeof(float); - const size_t src_stride = src->nb[1]; - - uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); - uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); - - for (uint32_t i = ir_first; i < ir_last; ++i) { - hex_l2fetch(src_data, src_row_size, src_stride, 2); - hvx_copy_f16_au(dst_data, src_data, ne0); - - dst_data += dst_stride; - src_data += src_stride; - } - - uint64_t t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "quantize-f16-f16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, - ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); - htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, ir_first); -} - - -static inline bool htp_is_permuted(const struct htp_tensor * t) { - return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; -} - -static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) { +static int hvx_mm_init_vec_dot(struct htp_mm_context * mmctx, enum htp_data_type type) { switch (type) { case HTP_TYPE_Q4_0: - mmctx->type = "q4x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_q4x4x2_q8x4x2_4x1; + mmctx->type = "q4_0_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_q4_0_32x1; return 0; case HTP_TYPE_Q4_1: - mmctx->type = "q4_1x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_q4_1x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_q4_1x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_q4_1x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_q4_1x4x2_q8x4x2_4x1; + mmctx->type = "q4_1_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_q4_1_32x1; return 0; case HTP_TYPE_Q8_0: - mmctx->type = "q8x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_q8x4x2_q8x4x2_4x1; + mmctx->type = "q8_0_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_q8_0_32x1; return 0; case HTP_TYPE_IQ4_NL: - mmctx->type = "iq4nlx4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_iq4nlx4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_iq4nlx4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_iq4nlx4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_iq4nlx4x2_q8x4x2_4x1; + mmctx->type = "iq4nl_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_iq4nl_32x1; return 0; case HTP_TYPE_MXFP4: - mmctx->type = "mxfp4x4x2-f32"; - mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; - mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; - mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; - mmctx->vec_dot_4x1 = vec_dot_mxfp4x4x2_q8x4x2_4x1; + mmctx->type = "mxfp4_tiled-f32"; + mmctx->vec_dot_32x1 = tiled_vec_dot_mxfp4_32x1; return 0; default: return -1; } } -static void htp_mminit_spad(struct htp_ops_context * octx, - size_t dst_row_size, - size_t src0_row_size_padded, - size_t src1_row_size, - uint32_t src1_nrows, - size_t src2_spad_size_per_thread) { - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - if (src2_spad_size_per_thread > 0) { - octx->src2_spad.size_per_thread = src2_spad_size_per_thread; - octx->src2_spad.size = octx->src2_spad.size_per_thread; - } - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; -} - -static int op_matmul_hvx(struct htp_ops_context * octx) { +static int hvx_mm_matmul(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - struct htp_matmul_context mmctx_struct = {0}; - struct htp_matmul_context * mmctx = &mmctx_struct; + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; mmctx->octx = octx; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t src0_nrows = ne01 * ne02 * ne03; const uint32_t src1_nrows = ne11 * ne12 * ne13; + bool is_repacked = (src0->type == HTP_TYPE_Q4_0 || src0->type == HTP_TYPE_Q4_1 || + src0->type == HTP_TYPE_Q8_0 || src0->type == HTP_TYPE_IQ4_NL || + src0->type == HTP_TYPE_MXFP4); + // Compute src0_nrows_per_thread mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + if (is_repacked) { + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); + } else { + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + } const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; @@ -4527,178 +1626,213 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { size_t src1_row_size_padded; worker_callback_t quant_job_func; - worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d; + worker_callback_t matmul_job_func; + uint32_t n_quant_jobs = 1; + if (src1_nrows > 1) { + if (is_repacked) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + matmul_job_func = hvx_mm_2d; + } + } else { + if (is_repacked) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mv_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mv_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mv_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mv_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mv_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + matmul_job_func = hvx_mv_2d; + } + } bool need_quant = true; - if (src0->type == HTP_TYPE_F16) { - // Try optimized f16-f16 path first (src1 in VTCM) - const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); - const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); - const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; + switch (kparams->kernel_type) { + case HTP_MM_KERNEL_HVX_F16_F16_VTCM: + quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16_flat : quantize_f16_f16_flat; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; + src1_row_size = hex_round_up(ne10 * 2, 128); + break; - const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; + case HTP_MM_KERNEL_HVX_F16_F32_DDR: + mmctx->type = "f16-f32"; + mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; + matmul_job_func = hvx_mm_4d; + mmctx->mm_div_ne12_ne1 = kparams->div_ne12_ne1; + mmctx->mm_div_ne1 = kparams->div_ne1; + mmctx->mm_div_r2 = kparams->div_r2; + mmctx->mm_div_r3 = kparams->div_r3; + need_quant = false; + quant_job_func = NULL; + src1_row_size = nb11; + break; - // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). - // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); + case HTP_MM_KERNEL_HVX_F16_F16_DDR: + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; + matmul_job_func = hvx_mm_4d; + mmctx->mm_div_ne12_ne1 = kparams->div_ne12_ne1; + mmctx->mm_div_ne1 = kparams->div_ne1; + mmctx->mm_div_r2 = kparams->div_r2; + mmctx->mm_div_r3 = kparams->div_r3; + src1_row_size = nb11; + need_quant = false; + quant_job_func = NULL; + break; - if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { - // Optimized path - quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16; - mmctx->type = "f16-f16"; - mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; - mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; - mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; + case HTP_MM_KERNEL_HVX_F32_F32_VTCM: + quant_job_func = quantize_f32_f32_flat; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; + src1_row_size = hex_round_up(ne10 * 4, 128); + break; - src1_row_size = f16_src1_row_size; // row size post quantization + case HTP_MM_KERNEL_HVX_F32_F32_DDR: + quant_job_func = NULL; + mmctx->type = "f32-f32"; + mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; + mmctx->mm_div_ne12_ne1 = kparams->div_ne12_ne1; + mmctx->mm_div_ne1 = kparams->div_ne1; + mmctx->mm_div_r2 = kparams->div_r2; + mmctx->mm_div_r3 = kparams->div_r3; + src1_row_size = nb11; + need_quant = false; + matmul_job_func = hvx_mm_4d; + break; - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_flat : quantize_f32_q8_0_flat; + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required - quant_job_func = NULL; - if (src1->type == HTP_TYPE_F32) { - mmctx->type = "f16-f32"; - mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; - matmul_job_func = matmul_4d; + if (src1_nrows > 1) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } } else { - mmctx->type = "f16-f16"; - mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; - matmul_job_func = matmul_4d; + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mv_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mv_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mv_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mv_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mv_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } + break; + } + + case HTP_MM_KERNEL_HVX_QUANT_BLOCK: + case HTP_MM_KERNEL_HVX_QUANT_ROW: + default: + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; } - src1_row_size = nb11; // original row size in DDR + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne10 + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - - // Init fastdiv for matmul_4d (supports broadcasting) - mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - - need_quant = false; - } - } else if (src0->type == HTP_TYPE_F32) { - // Try optimized f32-f32 path first (src1 in VTCM) - const size_t f32_src1_row_size = hex_round_up(ne10 * 4, 128); - const size_t f32_src1_spad_size = hex_round_up(f32_src1_row_size * src1_nrows, 256); - const size_t f32_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f32_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - - const size_t f32_total_size = f32_src1_spad_size + f32_src0_spad_size + f32_dst_spad_size; - - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(octx->src[0]) || htp_is_permuted(octx->src[1]); - - if (!is_batched && !is_permuted && f32_total_size <= octx->ctx->vtcm_size) { - // Optimized path - quant_job_func = quantize_f32_f32; - mmctx->type = "f32-f32"; - mmctx->vec_dot_1x1 = vec_dot_f32_f32_aa_1x1; - mmctx->vec_dot_2x1 = vec_dot_f32_f32_aa_2x1; - mmctx->vec_dot_2x2 = vec_dot_f32_f32_aa_2x2; - - src1_row_size = f32_src1_row_size; - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to DDR / broadcasting - quant_job_func = NULL; - mmctx->type = "f32-f32"; - mmctx->vec_dot_1x1 = vec_dot_f32_f32_uu_1x1; - matmul_job_func = matmul_4d; - - src1_row_size = nb11; - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - - // Init fastdiv for matmul_4d (supports broadcasting) - mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - - need_quant = false; - } - } else { - if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { - return HTP_STATUS_NO_SUPPORT; - } - - if (src0->type == HTP_TYPE_Q4_1) { - quant_job_func = quantize_f32_q8_1x4x2; - src1_row_size = q8_1x4x2_row_size(ne10); - } else { - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); - } - htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); + if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * ith) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + break; } - // VTCM scratchpads for all tensors - size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; + size_t src0_sz = 0, src1_sz = 0, dst_sz = 0; + if (kparams->vtcm_src0_size > 0 || kparams->vtcm_src1_size > 0 || kparams->vtcm_dst_size > 0) { + src0_sz = kparams->vtcm_src0_size; + src1_sz = kparams->vtcm_src1_size; + dst_sz = kparams->vtcm_dst_size; + } else { + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + htp_mm_hvx_get_vtcm_sizes( + kparams->kernel_type, src0->type, ne10, src1_nrows, octx->n_threads, + dst_row_size, src0_row_size, src1_row_size, n_prefetch, + &src0_sz, &src1_sz, &dst_sz + ); + } - FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, - octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_F16_F16_VTCM || + kparams->kernel_type == HTP_MM_KERNEL_HVX_F32_F32_VTCM || + kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW || + kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_BLOCK) { + mmctx->vtcm_src1_size_per_thread = src1_sz; + } else { + mmctx->vtcm_src1_size_per_thread = src1_sz / octx->n_threads; + } + + mmctx->vtcm_src0_size_per_thread = src0_sz / octx->n_threads; + mmctx->vtcm_dst_size_per_thread = dst_sz / octx->n_threads; + + size_t vtcm_size = kparams->vtcm_size > 0 ? (size_t)kparams->vtcm_size : (src1_sz + src0_sz + dst_sz); + + FARF(HIGH, "matmul-%s : src0-vtcm-size %zu src1-vtcm-size %zu dst-vtcm-size %zu (%zu)\n", mmctx->type, + src0_sz, src1_sz, dst_sz, vtcm_size); FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { + if (octx->ctx->vtcm_size < vtcm_size) { FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, - octx->ctx->vtcm_size, spad_size); + octx->ctx->vtcm_size, vtcm_size); return HTP_STATUS_VTCM_TOO_SMALL; } - // Place src1 spad first. We use it for dyn.quant and may reuse between ops - octx->src1_spad.data = octx->ctx->vtcm_base; - octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_dst = vtcm_seq_alloc(&vtcm_ptr, dst_sz); - octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; + octx->src1_spad.src = NULL; octx->src0_spad.src = NULL; octx->dst_spad.src = NULL; - octx->src0_spad.stride = src0_row_size_padded; - octx->src1_spad.stride = src1_row_size; + mmctx->vtcm_src0_stride = src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return HTP_STATUS_OK; - if (need_quant && !octx->src1_spad.src) { - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + if (need_quant) { mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - octx->src1_spad.src = src1; } const uint32_t n_matmul_jobs = octx->n_threads; @@ -4707,72 +1841,1209 @@ static int op_matmul_hvx(struct htp_ops_context * octx) { return HTP_STATUS_OK; } -int op_matmul(struct htp_ops_context * octx) { +static void hvx_mm_qkv_2d(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * restrict src0 = octx->src[0]; // Wk + const struct htp_tensor * restrict src1 = octx->src[1]; // x + const struct htp_tensor * restrict src2 = octx->src[2]; // Wv + const struct htp_tensor * restrict src3 = octx->src[3]; // Wq + const struct htp_tensor * restrict dst_k = octx->dsts[0]; + const struct htp_tensor * restrict dst_v = octx->dsts[1]; + const struct htp_tensor * restrict dst_q = octx->dsts[2]; + + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + + const uint32_t ne11 = src1->ne[1]; + const uint32_t ne12 = src1->ne[2]; + const uint32_t ne13 = src1->ne[3]; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src1_nrows = ne11 * ne12 * ne13; + + const uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + if (src0_start_row >= src0_end_row) { + return; + } + + const size_t dst_k_row_size = dst_k->nb[1]; // K and V share output width + const size_t dst_q_row_size = dst_q->nb[1]; // Q may be wider (GQA) + const size_t src0_row_size = src0->nb[1]; + const size_t src2_row_size = src2->nb[1]; + const size_t src3_row_size = src3->nb[1]; + + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src2_stride = mmctx->vtcm_src2_stride; + const size_t src3_stride = mmctx->vtcm_src3_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; + + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; + uint8_t * restrict vtcm_src3_ptr = mmctx->vtcm_src3 + mmctx->vtcm_src3_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; + + dma_queue * dma_queue = octx->ctx->dma[ith]; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; + const uint8_t * restrict src3_row = (const uint8_t *) src3->data; + + // Prefill spad with src0, src2, src3 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= (int)n_prefetch) { + break; + } + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + is0 * src3_stride, src3_row + ir0 * src3_row_size), + src3_stride, src3_row_size, src3_row_size, 2); + } + + // Process rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss3 = dma_queue_pop(dma_queue).dst; + + // Process src1 columns in pairs (2ร—2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + + float * restrict dst_row0_k = (float *) (dst_k->data + ((ir1+0) * dst_k_row_size)); + float * restrict dst_row1_k = (float *) (dst_k->data + ((ir1+1) * dst_k_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_k[ir0], &dst_row1_k[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + + float * restrict dst_row0_v = (float *) (dst_v->data + ((ir1+0) * dst_k_row_size)); + float * restrict dst_row1_v = (float *) (dst_v->data + ((ir1+1) * dst_k_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_v[ir0], &dst_row1_v[ir0], ss2, ss2 + src2_stride, src1_col0, src1_col1); + + float * restrict dst_row0_q = (float *) (dst_q->data + ((ir1+0) * dst_q_row_size)); + float * restrict dst_row1_q = (float *) (dst_q->data + ((ir1+1) * dst_q_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_q[ir0], &dst_row1_q[ir0], ss3, ss3 + src3_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2ร—1) + for (; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + + float * restrict dst_row_k = (float *) (dst_k->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_k[ir0], ss0, ss0 + src0_stride, src1_col); + + float * restrict dst_row_v = (float *) (dst_v->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_v[ir0], ss2, ss2 + src2_stride, src1_col); + + float * restrict dst_row_q = (float *) (dst_q->data + (ir1 * dst_q_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_q[ir0], ss3, ss3 + src3_stride, src1_col); + } + + // Prefetch next (n + vtcm_nrows) rows + const int pr0 = (ir0 + n_prefetch); + const int is0 = (pr0 - src0_start_row) & prefetch_mask; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + pr0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + is0 * src3_stride, src3_row + pr0 * src3_row_size), + src3_stride, src3_row_size, src3_row_size, 2); + } + } + + // Process last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const int is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 1); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src3_ptr + is0 * src3_stride, src3_row + ir0 * src3_row_size), + src3_stride, src3_row_size, src3_row_size, 1); + + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss3 = dma_queue_pop(dma_queue).dst; + + for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + + float * restrict dst_row_k = (float *) (dst_k->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_k[ir0], ss0, src1_col); + + float * restrict dst_row_v = (float *) (dst_v->data + (ir1 * dst_k_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_v[ir0], ss2, src1_col); + + float * restrict dst_row_q = (float *) (dst_q->data + (ir1 * dst_q_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_q[ir0], ss3, src1_col); + } + } +} + +static void hvx_mm_ffn_2d(unsigned int nth, unsigned int ith, void * data) { + struct htp_mm_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * restrict src0 = octx->src[0]; // Wgate + const struct htp_tensor * restrict src1 = octx->src[1]; // y + const struct htp_tensor * restrict src2 = octx->src[2]; // Wup + const struct htp_tensor * restrict dst_gate = octx->dsts[0]; + const struct htp_tensor * restrict dst_up = octx->dsts[1]; + + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + + const uint32_t ne11 = src1->ne[1]; + const uint32_t ne12 = src1->ne[2]; + const uint32_t ne13 = src1->ne[3]; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src1_nrows = ne11 * ne12 * ne13; + + const uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t src0_end_row_x2 = src0_start_row + ((src0_end_row - src0_start_row) & ~1U); + + if (src0_start_row >= src0_end_row) { + return; + } + + const size_t dst_row_size = dst_gate->nb[1]; + const size_t src0_row_size = src0->nb[1]; + const size_t src2_row_size = src2->nb[1]; + + const size_t src0_stride = mmctx->vtcm_src0_stride; + const size_t src2_stride = mmctx->vtcm_src2_stride; + const size_t src1_stride = mmctx->vtcm_src1_stride; + + uint8_t * restrict vtcm_src0_ptr = mmctx->vtcm_src0 + mmctx->vtcm_src0_size_per_thread * ith; + uint8_t * restrict vtcm_src2_ptr = mmctx->vtcm_src2 + mmctx->vtcm_src2_size_per_thread * ith; + uint8_t * restrict src1_data = mmctx->vtcm_src1; + + dma_queue * dma_queue = octx->ctx->dma[ith]; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const uint32_t n_prefetch = kparams->n_prefetch; + assert(n_prefetch >= 2 && n_prefetch <= HTP_MM_MAX_PREFETCH && (n_prefetch & (n_prefetch - 1)) == 0); + const uint32_t prefetch_mask = n_prefetch - 1; + + const uint8_t * restrict src0_row = (const uint8_t *) src0->data; + const uint8_t * restrict src2_row = (const uint8_t *) src2->data; + + // Prefill spad with src0, src2 rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const int is0 = (ir0 - src0_start_row); + if (is0 >= (int)n_prefetch) { + break; + } + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + } + + // Process rows + for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + + // Process src1 columns in pairs (2ร—2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + + float * restrict dst_row0_gate = (float *) (dst_gate->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1_gate = (float *) (dst_gate->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_gate[ir0], &dst_row1_gate[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + + float * restrict dst_row0_up = (float *) (dst_up->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1_up = (float *) (dst_up->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0_up[ir0], &dst_row1_up[ir0], ss2, ss2 + src2_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2ร—1) + for (; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + + float * restrict dst_row_gate = (float *) (dst_gate->data + (ir1 * dst_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_gate[ir0], ss0, ss0 + src0_stride, src1_col); + + float * restrict dst_row_up = (float *) (dst_up->data + (ir1 * dst_row_size)); + mmctx->vec_dot_2x1(ne00, &dst_row_up[ir0], ss2, ss2 + src2_stride, src1_col); + } + + // Prefetch next rows + const int pr0 = (ir0 + n_prefetch); + const int is0 = (pr0 - src0_start_row) & prefetch_mask; + if (pr0 < src0_end_row_x2) { + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 2); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + pr0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 2); + } + } + + // Process last row (if any) + if (src0_end_row != src0_end_row_x2) { + uint32_t ir0 = src0_end_row_x2; + const int is0 = (ir0 - src0_start_row) & prefetch_mask; + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src0_ptr + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, src0_row_size, 1); + dma_queue_push(dma_queue, dma_make_ptr(vtcm_src2_ptr + is0 * src2_stride, src2_row + ir0 * src2_row_size), + src2_stride, src2_row_size, src2_row_size, 1); + + const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; + const uint8_t * ss2 = dma_queue_pop(dma_queue).dst; + + for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); + + float * restrict dst_row_gate = (float *) (dst_gate->data + (ir1 * dst_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_gate[ir0], ss0, src1_col); + + float * restrict dst_row_up = (float *) (dst_up->data + (ir1 * dst_row_size)); + mmctx->vec_dot_1x1(ne00, &dst_row_up[ir0], ss2, src1_col); + } + } +} + +#define DEQUANTIZE_WORKER_LOOP_IMPL(SUFFIX) \ +static void dequantize_tiled_worker_loop_##SUFFIX(unsigned int n, unsigned int i, void *data) { \ + tiled_dequantize_state_t *state = (tiled_dequantize_state_t *)data; \ + struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; \ + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { \ + int start = task_id * state->n_tiles_per_task; \ + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); \ + dequantize_tiled_weight_to_fp16_task_##SUFFIX(state, start, end); \ + } \ + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); \ +} + +DEQUANTIZE_WORKER_LOOP_IMPL(q4_0) +DEQUANTIZE_WORKER_LOOP_IMPL(q4_1) +DEQUANTIZE_WORKER_LOOP_IMPL(iq4_nl) +DEQUANTIZE_WORKER_LOOP_IMPL(mxfp4) +DEQUANTIZE_WORKER_LOOP_IMPL(q8_0) + +static void convert_f16_worker_loop(unsigned int n, unsigned int i, void *data) { + tiled_dequantize_state_t *state = (tiled_dequantize_state_t *)data; + struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + convert_f16_weight_to_fp16_tiles_task(state, start, end); + } + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_W_DEQUANT, i); +} + +static void quantize_f32_worker_loop(unsigned int n, unsigned int i, void *data) { + tiled_dequantize_state_t *state = (tiled_dequantize_state_t *)data; + + struct htp_thread_trace * tr = state->traces ? &state->traces[i] : NULL; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_QUANT, i); + + for (unsigned int task_id = i; task_id < (unsigned int)state->n_tasks; task_id += n) { + int start = task_id * state->n_tiles_per_task; + int end = hex_smin(start + state->n_tiles_per_task, state->n_tot_tiles); + quantize_f32_weight_to_fp16_tiles_task(state, start, end); + } + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_QUANT, i); +} + +static void transfer_output_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_task_state_t *st = (output_transfer_task_state_t *) data; + + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + + int start_chunk_idx = i * st->n_chunks_per_task; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, start_chunk_idx); + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + float *dst = st->dst + chunk_idx * st->dst_stride; + transfer_output_chunk_fp16_to_fp32(dst, st->vtcm_src, chunk_idx, chunk_size, st->n_cols, st->dst_stride, st->dst_cols); + } + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, start_chunk_idx); +} + +static void transfer_activation_chunk_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_task_state_t *st = (activation_transfer_task_state_t *) data; + + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + + int start_chunk_idx = i * st->n_chunks_per_task; + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, start_chunk_idx); + + for (unsigned int task_id = i; task_id < (unsigned int)st->n_tasks; task_id += n) { + int chunk_idx = task_id * st->n_chunks_per_task; + size_t chunk_size = hex_smin(st->n_tot_chunks - chunk_idx, st->n_chunks_per_task); + + __fp16 *dst = st->dst + chunk_idx * st->k_block; + const float *src = st->src + chunk_idx * st->k_stride; + + if (st->vtcm_f32_act) { + float *thread_f32_act = st->vtcm_f32_act + i * HTP_MM_DMA_ACT_MULTIPLIER * st->k_block; + transfer_activation_chunk_fp32_to_fp16_dma_pipelined( + st->ctx->dma[i], dst, src, chunk_size, st->k_block, st->k_stride, st->k_valid, thread_f32_act + ); + } else { + transfer_activation_chunk_fp32_to_fp16(dst, src, chunk_size, st->k_block, st->k_stride, st->k_valid); + } + } + + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, start_chunk_idx); +} + +static void transfer_activation_chunk_gathered_worker_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + transfer_activation_chunk_fp32_to_fp16_gathered( + st->dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->ne11, &st->ne11_div, st->nb11, st->nb12, st->cne1, st->k_valid); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + } +} + +static void transfer_activation_chunk_gathered_worker_flat_fn(unsigned int n, unsigned int i, void *data) { + activation_transfer_gathered_task_state_t *st = data; + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + transfer_activation_chunk_fp32_to_fp16_gathered_flat( + st->dst, st->src, start_row, n_rows, st->k_block, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->nb12, st->cne1, st->k_valid); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_A_PREP, chunk_idx); + } +} + +static void transfer_output_chunk_scattered_worker_fn(unsigned int n, unsigned int i, void *data) { + output_transfer_scattered_task_state_t *st = data; + struct htp_thread_trace * tr = st->traces ? &st->traces[i] : NULL; + int chunk_idx = i; + int chunk_size = st->n_chunks_per_task; + int start_row = st->start_row + chunk_idx * chunk_size; + int n_rows = hex_smin(st->cne1 - start_row, chunk_size); + if (n_rows > 0) { + htp_trace_event_start(tr, HTP_TRACE_EVT_HVX_O_PROC, chunk_idx); + transfer_output_chunk_fp16_to_fp32_scattered( + st->dst, st->vtcm_src, start_row, n_rows, st->n_cols, + st->matrix_rows, st->cur_a, st->mapping_stride, + st->dst_nb1, st->dst_nb2, st->cne1); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HVX_O_PROC, chunk_idx); + } +} + +// --- HMX Dispatchers & Entry Points --- + +static void dequantize_tiled_weight_chunk_to_fp16_tiles( + struct htp_context *ctx, __fp16 *vtcm_dst, + const void *weight_src_ddr, + int n_cols, int k_block, + size_t row_stride, int weight_type, + int n_k_tiles, struct fastdiv_values n_k_tiles_div, + worker_callback_t dequant_worker_fn, int n_threads) { + + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + assert(k_block % HTP_MM_HMX_TILE_N_COLS == 0); + + size_t n_col_tiles = n_cols / HTP_MM_HMX_TILE_N_COLS; + size_t n_tot_tiles = n_col_tiles * n_k_tiles; + + size_t n_tiles_per_task = (n_threads == 1) ? n_tot_tiles : hmx_ceil_div(n_tot_tiles, n_threads); + + tiled_dequantize_state_t state; + state.n_tasks = (n_tot_tiles + n_tiles_per_task - 1) / n_tiles_per_task; + state.n_tot_tiles = n_tot_tiles; + state.n_tiles_per_task = n_tiles_per_task; + state.dst = vtcm_dst; + state.src = (const uint8_t *)weight_src_ddr; + state.n_cols = n_cols; + state.k_block = k_block; + state.row_stride = row_stride; + state.weight_type = weight_type; + state.n_k_tiles = n_k_tiles; + state.n_k_tiles_div = n_k_tiles_div; + state.traces = ctx->trace; + state.ctx = ctx; + + state.tile_size = htp_mm_get_weight_tile_size(weight_type); + state.aligned_tile_size = htp_mm_get_weight_aligned_tile_size(weight_type); + + if (state.n_tasks == 1 || n_threads == 1) { + dequant_worker_fn(1, 0, &state); + } else { + int n_tasks = hex_smin((int) state.n_tasks, n_threads); + worker_pool_run_func(ctx->worker_pool, dequant_worker_fn, &state, n_tasks); + } +} + +static void transfer_output_chunk_threaded(struct htp_context *ctx, float *dst, const __fp16 *vtcm_src, + int n_rows, int n_cols, int dst_stride, int dst_cols, int n_threads) { + assert(n_cols % HTP_MM_HMX_TILE_N_COLS == 0); + + if (n_rows <= 0) return; + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : hmx_ceil_div(n_rows, n_threads); + n_chunks_per_task = hex_align_up(n_chunks_per_task, 2); + + int actual_threads = hmx_ceil_div(n_rows, n_chunks_per_task); + + output_transfer_task_state_t state; + state.n_tasks = actual_threads; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.vtcm_src = vtcm_src; + state.n_cols = n_cols; + state.dst_stride = dst_stride; + state.dst_cols = dst_cols; + state.traces = ctx->trace; + + if (actual_threads <= 1) { + transfer_output_chunk_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_worker_fn, &state, actual_threads); + } +} + +static void transfer_activation_chunk_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int n_rows, + int k_block, + int k_stride, + int n_threads, + int k_valid, + float *vtcm_f32_act) { + assert(k_block % HTP_MM_HMX_TILE_N_COLS == 0 && k_stride % HTP_MM_HMX_TILE_N_COLS == 0); + + size_t n_tot_chunks = n_rows; + size_t n_chunks_per_task = (n_threads == 1) ? n_tot_chunks : 32; // must be multiple of 32 to ensure correct destination address + + activation_transfer_task_state_t state; + state.n_tasks = (n_tot_chunks + n_chunks_per_task - 1) / n_chunks_per_task; + state.n_tot_chunks = n_tot_chunks; + state.n_chunks_per_task = n_chunks_per_task; + state.dst = dst; + state.src = src; + state.k_block = k_block; + state.k_stride = k_stride; + state.k_valid = k_valid; + state.traces = ctx->trace; + state.ctx = ctx; + state.vtcm_f32_act = vtcm_f32_act; + + if (state.n_tasks == 1 || n_threads == 1) { + transfer_activation_chunk_worker_fn(1, 0, &state); + } else { + int n_tasks = hex_smin((int) state.n_tasks, n_threads); + worker_pool_run_func(ctx->worker_pool, transfer_activation_chunk_worker_fn, &state, n_tasks); + } +} + +static int hmx_mm_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *weight, + int m, int k, int n, + int act_stride, + int weight_stride, + int weight_type, + int k_valid, + int dst_stride, + int dst_cols, + int m_chunk, + int n_chunk, + int pipeline, + int n_threads, + int act_threads, + int tile_size, + int aligned_tile_size, + int vtcm_size) { + if (k % 32 != 0 || n % 32 != 0) { return -1; } + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN)) { return -1; } + + size_t row_stride = htp_mm_get_tiled_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_tiled_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_tiled_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_tiled_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_tiled_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_tiled_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } + + const int n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + + const bool is_quant = (weight_type != HTP_TYPE_F16 && weight_type != HTP_TYPE_F32); + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + + size_t m_chunk_n_rows = m_chunk; + size_t n_chunk_n_cols = n_chunk; + size_t vtcm_used = vtcm_size; + + const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0; + + const size_t act_f32_size = hex_align_up((size_t)act_threads * HTP_MM_DMA_ACT_MULTIPLIER * k * sizeof(float), HTP_MM_HMX_TILE_SIZE); + + const size_t weight_area_size = is_quant + ? hex_align_up((n_chunk_n_cols / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE) + : hex_align_up(n_chunk_n_cols * row_stride, HTP_MM_HMX_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HTP_MM_HMX_TILE_SIZE); + + size_t scratch0_size, scratch1_size, scratch2_size; + scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); // dequant buf 0 + scratch1_size = pipeline ? scratch0_size : 0; // dequant buf 1 + scratch2_size = pipeline ? output_area_size : 0; // output buf 1 + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight_raw[2] = { NULL, NULL }; + if (weight_area_size) { + if (pipeline) { + vtcm_weight_raw[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + vtcm_weight_raw[1] = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + } else { + vtcm_weight_raw[0] = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + } + } + __fp16 *vtcm_f16_act = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + float *vtcm_f32_act = (float *) vtcm_seq_alloc(&vtcm_ptr, act_f32_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + void *vtcm_scratch1 = scratch1_size ? vtcm_seq_alloc(&vtcm_ptr, scratch1_size) : NULL; + void *vtcm_scratch2 = scratch2_size ? vtcm_seq_alloc(&vtcm_ptr, scratch2_size) : NULL; + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-2d-precomputed: VTCM overflow: used %zu budget %zu, m %d k %d n %d mc %zu nc %zu", + vtcm_used, vtcm_budget, m, k, n, m_chunk_n_rows, n_chunk_n_cols); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "hmx-mm-2d-precomputed: standard : m %d k %d n %d wtype %d mc %zu nc %zu vtcm %zu/%zu", + m, k, n, weight_type, m_chunk_n_rows, n_chunk_n_cols, vtcm_used, vtcm_budget); + + int n_chunk_cnt = hmx_ceil_div(n, n_chunk_n_cols); + + if (pipeline) { + // --- Asynchronous Pipelined Loop --- + hmx_matmul_job_t job_slots[2]; // persistent double-buffered job descriptors + + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + void *vtcm_weight_bufs[2] = { vtcm_scratch0, vtcm_scratch1 }; + void *vtcm_output_bufs[2] = { vtcm_output, vtcm_scratch2 }; + + transfer_activation_chunk_threaded(ctx, vtcm_f16_act, activation + mr * act_stride, n_rows, k, act_stride, act_threads, k_valid, vtcm_f32_act); + + // Prologue: push A0 and optionally A1 (if n_chunk_cnt > 1) + const size_t n_cols_A0 = hex_smin(n - 0 * n_chunk_n_cols, n_chunk_n_cols); + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight), aligned_tile_size, tile_size, tile_size, (n_cols_A0 / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight), row_stride, weight_stride, row_stride, n_cols_A0); + } + + if (1 < n_chunk_cnt) { + const size_t n_cols_A1 = hex_smin(n - 1 * n_chunk_n_cols, n_chunk_n_cols); + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[1], weight + n_chunk_n_cols * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols_A1 / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[1], weight + n_chunk_n_cols * weight_stride), row_stride, weight_stride, row_stride, n_cols_A1); + } + } + + // pop A0 -> dequantize A0 -> submit C0 + dma_queue_pop(ctx->dma[0]); + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_weight_bufs[0], vtcm_weight_raw[0], + n_cols_A0, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads); + + hmx_matmul_job_init(&job_slots[0], (__fp16 *) vtcm_output_bufs[0], (__fp16 *) vtcm_f16_act, + (__fp16 *) vtcm_weight_bufs[0], vtcm_scales, + hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS), + hmx_ceil_div(n_cols_A0, HTP_MM_HMX_TILE_N_COLS), k / HTP_MM_HMX_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[0])); + + // Main loop: pop/dequantize A_{i+1} -> push A_{i+2} -> submit C_{i+1} -> wait C_i and store D_i + for (int i = 0; i < n_chunk_cnt; ++i) { + const size_t nc = i * n_chunk_n_cols; + const size_t nc_p1 = nc + 1 * n_chunk_n_cols; + const size_t nc_p2 = nc + 2 * n_chunk_n_cols; + + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_cols_p1 = hex_smin(n - nc_p1, n_chunk_n_cols); + const size_t n_cols_p2 = hex_smin(n - nc_p2, n_chunk_n_cols); + + // 1. pop A_{i+1} and dequantize it (if i+1 < n_chunk_cnt) + if (i + 1 < n_chunk_cnt) { + dma_queue_pop(ctx->dma[0]); + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_weight_bufs[(i + 1) % 2], vtcm_weight_raw[(i + 1) % 2], + n_cols_p1, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads); + } + + // 2. push A_{i+2} (if i+2 < n_chunk_cnt) + if (i + 2 < n_chunk_cnt) { + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[(i + 2) % 2], weight + nc_p2 * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols_p2 / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[(i + 2) % 2], weight + nc_p2 * weight_stride), row_stride, weight_stride, row_stride, n_cols_p2); + } + } + + // 3. submit C_{i+1} (if i+1 < n_chunk_cnt) + if (i + 1 < n_chunk_cnt) { + hmx_matmul_job_init(&job_slots[(i + 1) % 2], (__fp16 *) vtcm_output_bufs[(i + 1) % 2], + (__fp16 *) vtcm_f16_act, (__fp16 *) vtcm_weight_bufs[(i + 1) % 2], + vtcm_scales, hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS), + hmx_ceil_div(n_cols_p1, HTP_MM_HMX_TILE_N_COLS), k / HTP_MM_HMX_TILE_N_ROWS); + hmx_queue_push(ctx->hmx_queue, hmx_queue_make_desc(hmx_matmul_worker_fn, &job_slots[(i + 1) % 2])); + } + + // 4. wait C_i and store D_i (multi-thread HVX, parallel with C_{i+1}) + hmx_queue_pop(ctx->hmx_queue); + float *output_chunk = dst + (mr * dst_stride + nc); + int chunk_dst_cols = dst_cols - (int)nc; + if (chunk_dst_cols > 0) { + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output_bufs[i % 2], n_rows, n_cols, dst_stride, chunk_dst_cols, n_threads); + } + } + } + hmx_queue_suspend(ctx->hmx_queue); + } else { + // --- Synchronous Un-pipelined loop (m <= 32 or fallback) --- + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + for (size_t mr = 0; mr < m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m - mr, m_chunk_n_rows); + + transfer_activation_chunk_threaded(ctx, vtcm_f16_act, activation + mr * act_stride, n_rows, k, act_stride, act_threads, k_valid, vtcm_f32_act); + + for (size_t nc = 0; nc < n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin(n - nc, n_chunk_n_cols); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HTP_MM_HMX_TILE_N_COLS); + + // A: Weight DMA (Synchronous) + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight + nc * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight_raw[0], weight + nc * weight_stride), row_stride, weight_stride, row_stride, n_cols); + } + dma_queue_pop(ctx->dma[0]); + + // B: Weight Dequantize (Threaded) + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_scratch0, vtcm_weight_raw[0], + n_cols, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads); + + // C: HMX Compute (Synchronous) + core_dot_chunk_fp16(vtcm_output, vtcm_f16_act, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HTP_MM_HMX_TILE_N_ROWS); + + // D: Output Store + float *output_chunk = dst + (mr * dst_stride + nc); + int chunk_dst_cols = dst_cols - (int)nc; + if (chunk_dst_cols > 0) { + transfer_output_chunk_threaded(ctx, output_chunk, vtcm_output, n_rows, n_cols, dst_stride, chunk_dst_cols, n_threads); + } + } + } + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + } + + return 0; +} + +static inline int hmx_mm_batch_r2(const hmx_mm_f16_f32_batched_params_t *params) { + return params->ne02 > 0 ? params->ne12 / params->ne02 : 1; +} + +static inline int hmx_mm_batch_r3(const hmx_mm_f16_f32_batched_params_t *params) { + return params->ne03 > 0 ? params->ne13 / params->ne03 : 1; +} + +static inline const __fp16 *hmx_mm_weight_batch_ptr(const hmx_mm_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + const int r2 = hmx_mm_batch_r2(params); + const int r3 = hmx_mm_batch_r3(params); + return (const __fp16 *) ((const uint8_t *) params->weight + + (size_t) (dst_b2 / r2) * params->src0_nb2 + + (size_t) (dst_b3 / r3) * params->src0_nb3); +} + +static inline const float *hmx_mm_activation_batch_ptr(const hmx_mm_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (const float *) ((const uint8_t *) params->activation + + (size_t) dst_b2 * params->src1_nb2 + + (size_t) dst_b3 * params->src1_nb3); +} + +static inline float *hmx_mm_dst_batch_ptr(const hmx_mm_f16_f32_batched_params_t *params, + int dst_b2, int dst_b3) { + return (float *) ((uint8_t *) params->dst + + (size_t) dst_b2 * params->dst_nb2 + + (size_t) dst_b3 * params->dst_nb3); +} + +static int hmx_mm_f16_f32_batched_simple(struct htp_context *ctx, + const hmx_mm_f16_f32_batched_params_t *params, + int m_chunk, int n_chunk, int pipeline, int n_threads, int act_threads, int vtcm_size) { + int ret = 0; + for (int b3 = 0; b3 < params->ne13 && ret == 0; ++b3) { + for (int b2 = 0; b2 < params->ne12 && ret == 0; ++b2) { + ret = hmx_mm_2d_f32(ctx, hmx_mm_dst_batch_ptr(params, b2, b3), + hmx_mm_activation_batch_ptr(params, b2, b3), + (const uint8_t *)hmx_mm_weight_batch_ptr(params, b2, b3), + params->m, params->k, params->n, + params->act_stride, params->weight_stride * (int)sizeof(__fp16), + HTP_TYPE_F16, params->k, params->n, params->n, + m_chunk, n_chunk, pipeline, n_threads, act_threads, + 0, 0, vtcm_size); + } + } + return ret; +} + +static int hmx_mm_f16_f32_batched(struct htp_context *ctx, const hmx_mm_f16_f32_batched_params_t *params, + int m_chunk, int n_chunk, int pipeline, int n_threads, int act_threads, int vtcm_size) { + if (params->act_stride < params->k || params->weight_stride < params->k || params->dst_stride < params->n) { return -1; } + if (params->ne02 <= 0 || params->ne03 <= 0 || params->ne12 <= 0 || params->ne13 <= 0) { return -1; } + if (params->ne12 % params->ne02 != 0 || params->ne13 % params->ne03 != 0) { return -1; } + if (params->k % 32 != 0 || params->n % 32 != 0) { return -1; } + if (!hex_is_aligned(params->dst, VLEN) || !hex_is_aligned(params->activation, VLEN)) { return -1; } + + const int group_size = hmx_mm_batch_r2(params); + const size_t vtcm_budget = ctx->vtcm_size; + + // Check if the precomputed parameters are grouped or simple. + // If simple, or if group_size <= 1, we use simple fallback loop. + // Grouped path is only valid if group_size > 1 and it fits within VTCM budget. + bool run_grouped = (group_size > 1 && (size_t)vtcm_size <= vtcm_budget); + if (!run_grouped) { + return hmx_mm_f16_f32_batched_simple(ctx, params, m_chunk, n_chunk, pipeline, n_threads, act_threads, vtcm_size); + } + + const size_t vec_dot_size = params->k * sizeof(__fp16); + + const bool use_dma_activation = (params->act_stride > params->k); + const size_t f32_scratch_size = use_dma_activation + ? hex_align_up((size_t)act_threads * HTP_MM_DMA_ACT_MULTIPLIER * (size_t) params->k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0; + + size_t m_chunk_n_rows = m_chunk; + size_t n_chunk_n_cols = n_chunk; + size_t vtcm_used = vtcm_size; + + const size_t act_head_stride = m_chunk_n_rows * (size_t) params->k; // fp16 elements between heads + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t activation_area_size = hex_align_up(group_size * m_chunk_n_rows * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HTP_MM_HMX_TILE_SIZE); + const size_t scratch_area_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size); + __fp16 *vtcm_f16_act = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, activation_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + void *vtcm_scratch1 = vtcm_seq_alloc(&vtcm_ptr, scratch_area_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + float *vtcm_f32_act = use_dma_activation ? (float *) vtcm_seq_alloc(&vtcm_ptr, f32_scratch_size) : NULL; + + if ((size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base) > vtcm_budget) { + FARF(HIGH, "%s: grouped layout overflowed VTCM, falling back to simple batched loop", __func__); + return hmx_mm_f16_f32_batched_simple(ctx, params, m_chunk, n_chunk, pipeline, n_threads, act_threads, vtcm_size); + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); // scale: 1.0, bias: 0.0 in FP16 + + FARF(HIGH, "%s: grouped path m=%d k=%d n=%d group=%d streams=%d mc=%zu nc=%zu vtcm=%zu/%zu", + __func__, params->m, params->k, params->n, group_size, params->ne13, + m_chunk_n_rows, n_chunk_n_cols, + (size_t) (vtcm_ptr - (uint8_t *) ctx->vtcm_base), vtcm_budget); + + const size_t fp16_row_bytes = (size_t) params->k * sizeof(__fp16); + const size_t weight_row_bytes = (size_t) params->weight_stride * sizeof(__fp16); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (int b3 = 0; b3 < params->ne13; ++b3) { + for (int b2_base = 0; b2_base < params->ne12; b2_base += group_size) { + const __fp16 *weight_group = hmx_mm_weight_batch_ptr(params, b2_base, b3); + + for (size_t mr = 0; mr < (size_t) params->m; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin((size_t) params->m - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div((int) n_rows, HTP_MM_HMX_TILE_N_ROWS); + + // Pre-load activations for all heads in the group (once per m_chunk). + // When the source is strided (permuted Q), use 2D DMA to gather + // contiguous rows into a VTCM scratch buffer first, then HVX + // converts from the contiguous VTCM buffer. This avoids L2 cache + // thrashing from HVX loads at large strides. + for (int g = 0; g < group_size; ++g) { + const float *activation_chunk = hmx_mm_activation_batch_ptr(params, b2_base + g, b3) + mr * params->act_stride; + __fp16 *vtcm_act_g = vtcm_f16_act + (size_t) g * act_head_stride; + if (use_dma_activation) { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride, act_threads, params->k, vtcm_f32_act); + } else { + transfer_activation_chunk_threaded(ctx, vtcm_act_g, + activation_chunk, (int) n_rows, + params->k, params->act_stride, act_threads, params->k, NULL); + } + } + + void *buf_curr = vtcm_scratch0; + void *buf_next = vtcm_scratch1; + + { + const size_t n_cols_first = hex_smin((size_t) params->n, n_chunk_n_cols); + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_curr, weight_group), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_first); + } + + for (size_t nc = 0; nc < (size_t) params->n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) params->n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div((int) n_cols, HTP_MM_HMX_TILE_N_COLS); + + { + dma_queue_pop(ctx->dma[0]); + + const size_t nc_next = nc + n_chunk_n_cols; + if (nc_next < (size_t) params->n) { + const size_t n_cols_next = hex_smin((size_t) params->n - nc_next, n_chunk_n_cols); + const __fp16 *next_weight_chunk = weight_group + nc_next * params->weight_stride; + + dma_queue_push(ctx->dma[0], dma_make_ptr(buf_next, next_weight_chunk), + fp16_row_bytes, weight_row_bytes, fp16_row_bytes, n_cols_next); + } + + hmx_interleave_rows_to_tiles(vtcm_weight, (const __fp16 *) buf_curr, n_cols, params->k, params->k, 0, n_cols); + hex_swap_ptr(&buf_curr, &buf_next); + } + + // Reuse the interleaved weight for every q_head in this GQA group + for (int g = 0; g < group_size; ++g) { + struct htp_thread_trace * tr = &ctx->trace[HTP_MAX_NTHREADS]; + htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, g); + { + const __fp16 * vtcm_act_g = vtcm_f16_act + (size_t) g * act_head_stride; + core_dot_chunk_fp16(vtcm_output, vtcm_act_g, vtcm_weight, vtcm_scales, n_row_tiles, n_col_tiles, + params->k / 32); + } + htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, g); + + { + float *output = hmx_mm_dst_batch_ptr(params, b2_base + g, b3) + mr * params->dst_stride + nc; + int chunk_dst_cols = params->n - (int)nc; + if (chunk_dst_cols > 0) { + transfer_output_chunk_threaded(ctx, output, vtcm_output, (int) n_rows, (int) n_cols, params->dst_stride, chunk_dst_cols, ctx->n_threads); + } + } + } + } + } + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + + return 0; +} + +static void transfer_activation_chunk_gathered_threaded( + struct htp_context *ctx, + __fp16 *dst, + const float *src, + int start_row, + int n_rows, + int k_block, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + int ne11, + size_t nb11, + size_t nb12, + int cne1, + int n_threads, + int k_valid) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, 2); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + activation_transfer_gathered_task_state_t state = { + .dst = dst, + .src = src, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .k_block = k_block, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .ne11 = ne11, + .ne11_div = ne11 > 1 ? init_fastdiv_values(ne11) : (struct fastdiv_values){0, 0}, + .nb11 = nb11, + .nb12 = nb12, + .start_row = start_row, + .cne1 = cne1, + .k_valid = k_valid, + .traces = ctx->trace, + }; + + worker_callback_t worker_fn = ne11 == 1 ? transfer_activation_chunk_gathered_worker_flat_fn : + transfer_activation_chunk_gathered_worker_fn; + + if (actual_threads <= 1) { + worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, worker_fn, &state, actual_threads); + } +} + +static void transfer_output_chunk_scattered_threaded( + struct htp_context *ctx, + float *dst, + const __fp16 *vtcm_src, + int start_row, + int n_rows, + int n_cols, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride, + size_t dst_nb1, + size_t dst_nb2, + int cne1, + int n_threads) { + if (n_rows <= 0) return; + int chunks_per_thread = hmx_ceil_div(n_rows, n_threads); + chunks_per_thread = hex_align_up(chunks_per_thread, 2); + + int actual_threads = hmx_ceil_div(n_rows, chunks_per_thread); + + output_transfer_scattered_task_state_t state = { + .vtcm_src = vtcm_src, + .dst = dst, + .n_tasks = actual_threads, + .n_tot_chunks = n_rows, + .n_chunks_per_task = chunks_per_thread, + .n_cols = n_cols, + .matrix_rows = matrix_rows, + .cur_a = cur_a, + .mapping_stride = mapping_stride, + .dst_nb1 = dst_nb1, + .dst_nb2 = dst_nb2, + .start_row = start_row, + .cne1 = cne1, + .traces = ctx->trace, + }; + + if (actual_threads <= 1) { + transfer_output_chunk_scattered_worker_fn(1, 0, &state); + } else { + worker_pool_run_func(ctx->worker_pool, transfer_output_chunk_scattered_worker_fn, &state, actual_threads); + } +} + +static int hmx_mm_id_2d_f32(struct htp_context *ctx, + float *restrict dst, + const float *activation, + const uint8_t *weight, + int m, int k, int n, + int k_valid, + int ne11, + size_t act_nb1, size_t act_nb2, + size_t dst_nb1, size_t dst_nb2, + int weight_stride, + int weight_type, + const struct mmid_row_mapping *matrix_rows, + int cur_a, + int mapping_stride) { + const int cne1 = m; + const int m_padded = hex_align_up(m, 32); + + if (k % 32 != 0 || n % 32 != 0) { return -1; } + if (!hex_is_aligned(dst, VLEN) || !hex_is_aligned(activation, VLEN)) { return -1; } + + size_t row_stride = htp_mm_get_tiled_row_stride(weight_type, k); + if (row_stride == 0) { + return -1; + } + + worker_callback_t dequant_worker_fn = NULL; + switch (weight_type) { + case HTP_TYPE_Q4_0: dequant_worker_fn = dequantize_tiled_worker_loop_q4_0; break; + case HTP_TYPE_IQ4_NL: dequant_worker_fn = dequantize_tiled_worker_loop_iq4_nl; break; + case HTP_TYPE_Q4_1: dequant_worker_fn = dequantize_tiled_worker_loop_q4_1; break; + case HTP_TYPE_MXFP4: dequant_worker_fn = dequantize_tiled_worker_loop_mxfp4; break; + case HTP_TYPE_Q8_0: dequant_worker_fn = dequantize_tiled_worker_loop_q8_0; break; + case HTP_TYPE_F16: dequant_worker_fn = convert_f16_worker_loop; break; + case HTP_TYPE_F32: dequant_worker_fn = quantize_f32_worker_loop; break; + default: + return -1; + } + + const int n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const struct fastdiv_values n_k_tiles_div = init_fastdiv_values(n_k_tiles); + + const int n_threads = ctx->n_threads; + const bool is_quant = (weight_type != HTP_TYPE_F16 && weight_type != HTP_TYPE_F32); + + const size_t vec_dot_size = k * sizeof(__fp16); + const size_t vtcm_budget = ctx->vtcm_size; + size_t vtcm_used = 0; + + int tile_size = htp_mm_get_weight_tile_size(weight_type); + int aligned_tile_size = htp_mm_get_weight_aligned_tile_size(weight_type); + + const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0; + const size_t weight_row_stride = is_quant ? qweight_row_stride : row_stride; + + size_t size_per_n = 0, size_per_m = 0, size_per_mn = 0; + htp_mm_hmx_get_2d_chunk_costs(weight_type, k, /*pipeline=*/false, aligned_tile_size, + &size_per_n, &size_per_m, &size_per_mn); + + size_t m_chunk_n_rows = 0, n_chunk_n_cols = 0; + if (htp_mm_hmx_compute_chunks(vtcm_budget, /*overhead=*/256, size_per_n, size_per_m, size_per_mn, + m_padded, n, + /*m_block_cost=*/(size_t) n * HTP_MM_HMX_COST_W_DEQUANT, + /*n_block_cost=*/(size_t) m_padded * HTP_MM_HMX_COST_A_CONVERT, &m_chunk_n_rows, &n_chunk_n_cols, &vtcm_used)) { + FARF(ERROR, "hmx-mm-id-2d: VTCM too small : m %d k %d n %d budget %zu", m_padded, k, n, vtcm_budget); + return -1; + } + + const size_t weight_area_size = hex_align_up(n_chunk_n_cols * weight_row_stride, HTP_MM_HMX_TILE_SIZE); + const size_t act_area_size = hex_align_up(m_chunk_n_rows * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = hex_align_up(m_chunk_n_rows * n_chunk_n_cols * sizeof(__fp16), HTP_MM_HMX_TILE_SIZE); + + size_t scratch0_size = hex_align_up(n_chunk_n_cols * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + + uint8_t *vtcm_ptr = (uint8_t *) ctx->vtcm_base; + __fp16 *vtcm_weight = weight_area_size ? (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, weight_area_size) : NULL; + __fp16 *vtcm_f16_act = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, act_area_size); + __fp16 *vtcm_output = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, output_area_size); + void *vtcm_scratch0 = vtcm_seq_alloc(&vtcm_ptr, scratch0_size); + __fp16 *vtcm_scales = (__fp16 *) vtcm_seq_alloc(&vtcm_ptr, 256); + + vtcm_used = vtcm_ptr - (uint8_t *) ctx->vtcm_base; + if (vtcm_used > vtcm_budget) { + FARF(ERROR, "hmx-mm-id-2d: VTCM overflow: used %zu budget %zu", vtcm_used, vtcm_budget); + return -1; + } + + hmx_init_column_scales(vtcm_scales, Q6_V_vsplat_R(0x3c00)); + + HAP_compute_res_hmx_lock(ctx->vtcm_rctx); + + for (size_t mr = 0; mr < (size_t) m_padded; mr += m_chunk_n_rows) { + const size_t n_rows = hex_smin(m_padded - mr, m_chunk_n_rows); + const size_t n_row_tiles = hmx_ceil_div(n_rows, HTP_MM_HMX_TILE_N_ROWS); + + transfer_activation_chunk_gathered_threaded( + ctx, vtcm_f16_act, activation, (int) mr, (int) n_rows, k, + matrix_rows, cur_a, mapping_stride, ne11, act_nb1, act_nb2, cne1, n_threads, k_valid); + + for (size_t nc = 0; nc < (size_t) n; nc += n_chunk_n_cols) { + const size_t n_cols = hex_smin((size_t) n - nc, n_chunk_n_cols); + const size_t n_col_tiles = hmx_ceil_div(n_cols, HTP_MM_HMX_TILE_N_COLS); + + if (is_quant) { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, weight + nc * weight_stride), aligned_tile_size, tile_size, tile_size, (n_cols / 32) * n_k_tiles); + } else { + dma_queue_push(ctx->dma[0], dma_make_ptr(vtcm_weight, weight + nc * weight_stride), row_stride, weight_stride, row_stride, n_cols); + } + dma_queue_pop(ctx->dma[0]); + + dequantize_tiled_weight_chunk_to_fp16_tiles( + ctx, vtcm_scratch0, vtcm_weight, + n_cols, k, row_stride, weight_type, + n_k_tiles, n_k_tiles_div, dequant_worker_fn, n_threads + ); + + struct htp_thread_trace * tr = &ctx->trace[HTP_MAX_NTHREADS]; + htp_trace_event_start(tr, HTP_TRACE_EVT_HMX_COMP, nc); + core_dot_chunk_fp16(vtcm_output, vtcm_f16_act, vtcm_scratch0, vtcm_scales, n_row_tiles, n_col_tiles, k / HTP_MM_HMX_TILE_N_ROWS); + htp_trace_event_stop(tr, HTP_TRACE_EVT_HMX_COMP, nc); + + transfer_output_chunk_scattered_threaded( + ctx, dst + nc, vtcm_output, (int) mr, (int) n_rows, (int) n_cols, + matrix_rows, cur_a, mapping_stride, dst_nb1, dst_nb2, cne1, n_threads); + } + } + + HAP_compute_res_hmx_unlock(ctx->vtcm_rctx); + return 0; +} + + +// --- Dispatchers and Public Entry Points --- + +static int hmx_mm_op_matmul(struct htp_ops_context * octx, const struct htp_mm_kernel_params * kparams) { htp_matmul_tensors_preamble; -#ifndef HTP_HAS_HMX - return op_matmul_hvx(octx); -#else - if (!octx->ctx->hmx_enabled) { - return op_matmul_hvx(octx); - } - - // HMX weight tile requires N to be 32-aligned. - if (src0->ne[1] % 32 != 0) { - return op_matmul_hvx(octx); - } - - // HMX supports F16, F32, Q4_0, Q8_0, IQ4_NL, MXFP4 weights. - // Other types fall back to HVX. - uint32_t wtype = src0->type; - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && wtype != HTP_TYPE_Q4_0 && wtype != HTP_TYPE_Q4_1 && wtype != HTP_TYPE_Q8_0 && wtype != HTP_TYPE_IQ4_NL && wtype != HTP_TYPE_MXFP4) { - return op_matmul_hvx(octx); - } - - // Quantised HMX path requires K aligned to 256 (x4x2 super-block). - // F16 and F32 HMX paths require K aligned to 32 (tile width). - if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && src0->ne[0] % 256 != 0) { - return op_matmul_hvx(octx); - } - - if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && src0->ne[0] % 32 != 0) { - return op_matmul_hvx(octx); - } - - const bool is_batched = (src0->ne[2] * src0->ne[3] > 1 || src1->ne[2] * src1->ne[3] > 1); - - // Quantised HMX kernels only handle flat 2D matmul (host already rejects - // batched quantised, but guard here too). F16 batched matmul is handled - // by the dedicated wrapper in hmx-matmul-ops.c. - if (is_batched && src0->type != HTP_TYPE_F16) { - return op_matmul_hvx(octx); - } - - // HMX assumes contiguous row-major layout. Fall back for permuted - // tensors where strides are non-monotonic (e.g. transposed KV cache). - if (src0->nb[0] > src0->nb[1] || src1->nb[0] > src1->nb[1]) { - return op_matmul_hvx(octx); - } - - // M alignment: Use HMX when M >= 32, the last partial tile (m_total % 32 rows) - // is handled by HMX itself; when M < 32 fall back to HVX. - const int m_total = (int) src1->ne[1]; - const int m_hmx = m_total & ~31; // 0 when M < 32 - if (m_hmx == 0) { - return op_matmul_hvx(octx); - } - - // Always re-quantize src1 since HMX kernel overwrites vtcm/spad, - // so any previously cached quantized data is invalid. - octx->src1_spad.src = NULL; - - int k = (int) src0->ne[0]; // inner dimension - int n = (int) src0->ne[1]; // weight columns - - int ret = -1; - - // Row strides in elements. For compact tensors these equal k; for - // permuted attention views they can be larger, so pass the real stride. + int k = (int) src0->ne[0]; + int n = (int) src0->ne[1]; + const int m_total = (int) src1->ne[1]; const int act_stride = (int)(src1->nb[1] / sizeof(float)); const int wgt_stride = (int)(src0->nb[1] / sizeof(__fp16)); @@ -4780,54 +3051,204 @@ int op_matmul(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - if (is_batched) { - if (src0->type == HTP_TYPE_F16) { - hmx_matmul_f16_f32_batched_params_t batch_params = { - .dst = (float *) dst->data, - .activation = (float *) src1->data, - .permuted_weight = (const __fp16 *) src0->data, - .m = m_total, - .k = k, - .n = n, - .act_stride = act_stride, - .weight_stride = wgt_stride, - .dst_stride = (int) (dst->nb[1] / sizeof(float)), - .ne02 = ne02, - .ne03 = ne03, - .ne12 = ne12, - .ne13 = ne13, - .src0_nb2 = src0->nb[2], - .src0_nb3 = src0->nb[3], - .src1_nb2 = src1->nb[2], - .src1_nb3 = src1->nb[3], - .dst_nb2 = dst->nb[2], - .dst_nb3 = dst->nb[3], - }; - ret = hmx_matmul_f16_f32_batched(octx->ctx, &batch_params); - } else { - return op_matmul_hvx(octx); - } + int ret = -1; + const int n_threads = MIN(kparams->n_threads, (int) octx->n_threads); + if (kparams->kernel_type == HTP_MM_KERNEL_HMX_F16_BATCHED) { + hmx_mm_f16_f32_batched_params_t batch_params = { + .dst = (float *) dst->data, + .activation = (float *) src1->data, + .weight = (const __fp16 *) src0->data, + .m = m_total, + .k = k, + .n = n, + .act_stride = act_stride, + .weight_stride = wgt_stride, + .dst_stride = (int) (dst->nb[1] / sizeof(float)), + .ne02 = ne02, + .ne03 = ne03, + .ne12 = ne12, + .ne13 = ne13, + .src0_nb2 = src0->nb[2], + .src0_nb3 = src0->nb[3], + .src1_nb2 = src1->nb[2], + .src1_nb3 = src1->nb[3], + .dst_nb2 = dst->nb[2], + .dst_nb3 = dst->nb[3], + }; + ret = hmx_mm_f16_f32_batched(octx->ctx, &batch_params, + kparams->m_chunk, kparams->n_chunk, + kparams->pipeline, n_threads, + kparams->n_act_threads, + kparams->vtcm_size); } else { - ret = hmx_matmul_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, - m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type); + ret = hmx_mm_2d_f32( + octx->ctx, (float*) dst->data, (float*) src1->data, (const uint8_t *) src0->data, + m_total, k, n, act_stride, (int) src0->nb[1], (int) src0->type, (int) src1->ne[0], + (int)(dst->nb[1] / sizeof(float)), (int)dst->ne[0], + kparams->m_chunk, kparams->n_chunk, kparams->pipeline, n_threads, + kparams->n_act_threads, + kparams->tile_size, kparams->aligned_tile_size, kparams->vtcm_size + ); } if (ret != 0) { - FARF(HIGH, "HMX matmul failed (ret=%d), falling back to HVX", ret); - return op_matmul(octx); + FARF(ERROR, "HMX matmul failed (ret=%d)\n", ret); + return HTP_STATUS_INTERNAL_ERR; + } + return HTP_STATUS_OK; +} + +int op_matmul(struct htp_ops_context * octx) { + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + + if (kparams->n_hmx) { + return hmx_mm_op_matmul(octx, kparams); } - return 0; -#endif // HTP_HAS_HMX + return hvx_mm_matmul(octx); +} + +static int hmx_mm_op_matmul_id( + struct htp_ops_context * octx, + struct htp_mm_context * mmctx, + const uint32_t * matrix_row_counts, + const struct mmid_row_mapping * matrix_rows, + void * mapping_buf, + bool must_free_mapping +) { + htp_matmul_tensors_preamble; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const int n_ids = octx->src[2]->ne[0]; + const int n_as = ne02; + + for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { + const int32_t cne1 = matrix_row_counts[cur_a]; + if (cne1 == 0) continue; + + int ret = hmx_mm_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, + (const uint8_t *) src0->data + cur_a * nb02, + cne1, ne00, ne01, + ne10, + ne11, + nb11, nb12, + nb1, nb2, + (int) src0->nb[1], (int) src0->type, + matrix_rows, cur_a, n_ids * octx->src[2]->ne[1]); + if (ret != 0) { + FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_NO_SUPPORT; + } + } + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; +} + +static int hvx_mm_op_matmul_id( + struct htp_ops_context * octx, + struct htp_mm_context * mmctx, + size_t src0_row_size_padded, + uint32_t src1_nrows, + worker_callback_t matmul_id_job_func, + void * mapping_buf, + bool must_free_mapping +) { + htp_matmul_tensors_preamble; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const struct htp_tensor * restrict ids = octx->src[2]; + const size_t src0_row_size = nb01; + + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (ne10 + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; + + worker_callback_t quant_job_func; + uint32_t n_quant_jobs = 1; + if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * ith) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } + size_t src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + + // Scratchpad sizes are computed on the host (htp_mm_hvx_id_get_vtcm_sizes) and passed in. + // The ID layout is routing-independent, so the host has exact visibility -- consume it here + // rather than recomputing, to keep host budgeting and device allocation in lockstep. + size_t src0_sz = kparams->vtcm_src0_size; + size_t src1_sz = kparams->vtcm_src1_size; + size_t src2_sz = 0; // mapping lives in DDR + size_t dst_sz = 0; // ID kernels scatter straight to DDR + size_t vtcm_size = kparams->vtcm_size; + + size_t src0_sz_per_thread = src0_sz / octx->n_threads; + size_t src1_sz_per_thread = src1_sz; + size_t src2_sz_per_thread = 0; + size_t dst_sz_per_thread = 0; + + FARF(HIGH, "matmul-id-%s : src0-spad-size %zu src1-spad-size %zu src2-spad-size %zu dst-spad-size %zu (%zu)\n", mmctx->type, + src0_sz, src1_sz, src2_sz, dst_sz, vtcm_size); + + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, + src1->data, dst->data); + + // Make sure the reserved vtcm size is sufficient + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, vtcm_size); + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_src2 = vtcm_seq_alloc(&vtcm_ptr, src2_sz); + mmctx->vtcm_dst = vtcm_seq_alloc(&vtcm_ptr, dst_sz); + + octx->src1_spad.src = NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->dst_spad.src = NULL; + + mmctx->vtcm_src0_stride = src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; + + mmctx->vtcm_src0_size_per_thread = src0_sz_per_thread; + mmctx->vtcm_src1_size_per_thread = src1_sz_per_thread; + mmctx->vtcm_src2_size_per_thread = src2_sz_per_thread; + mmctx->vtcm_dst_size_per_thread = dst_sz_per_thread; + + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + + const uint32_t n_matmul_jobs = octx->n_threads; + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + + if (must_free_mapping) free(mapping_buf); + return HTP_STATUS_OK; } int op_matmul_id(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - struct htp_matmul_context mmctx_struct = {0}; - struct htp_matmul_context * mmctx = &mmctx_struct; + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; mmctx->octx = octx; + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + const struct htp_tensor * restrict ids = octx->src[2]; const size_t src0_row_size = nb01; @@ -4839,14 +3260,11 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t src1_nrows = ne11 * ne12 * ne13; worker_callback_t quant_job_func; - worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id; + worker_callback_t matmul_id_job_func = src1_nrows > 1 ? hvx_mm_id : hvx_mv_id; // Compute src0_nrows_per_thread mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even - - size_t src1_row_size; - size_t src1_row_size_padded; + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); // row groups const int n_ids = ids->ne[0]; // n_expert_used @@ -4875,54 +3293,13 @@ int op_matmul_id(struct htp_ops_context * octx) { mmctx->matrix_row_counts = matrix_row_counts; mmctx->matrix_rows = matrix_rows; + mmctx->mm_div_ne11 = kparams->div_ne11; - if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { if (must_free_mapping) free(mapping_buf); return HTP_STATUS_NO_SUPPORT; } - if (src0->type == HTP_TYPE_Q4_1) { - quant_job_func = quantize_f32_q8_1x4x2; - src1_row_size = q8_1x4x2_row_size(ne10); - } else { - quant_job_func = quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); - } - - const size_t src2_spad_size_per_thread = 0; // We moved the mapping to DDR! - htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); - - size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - - FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, - octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); - - FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, - src1->data, dst->data); - - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); - if (must_free_mapping) free(mapping_buf); - return HTP_STATUS_VTCM_TOO_SMALL; - } - - // Place src1 spad first. We use it for dyn.quant and may reuse in subseq ops. - octx->src1_spad.data = octx->ctx->vtcm_base; - octx->src0_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src2_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; - - octx->src1_spad.src = (src1 == octx->src1_spad.src) ? src1 : NULL; - octx->src0_spad.src = NULL; - octx->src2_spad.src = NULL; - octx->dst_spad.src = NULL; - - octx->src0_spad.stride = src0_row_size_padded; - octx->src1_spad.stride = src1_row_size; - if (src1_nrows > 1) { // initialize matrix_row_counts and map memset(matrix_row_counts, 0, n_as * sizeof(uint32_t)); @@ -4930,9 +3307,12 @@ int op_matmul_id(struct htp_ops_context * octx) { // group rows by src0 matrix for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx for (uint32_t id = 0; id < n_ids; ++id) { // expert idx - const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + const int32_t i02 = *(const int32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); - assert(i02 >= 0 && i02 < n_as); + if (i02 < 0) { + continue; + } + assert(i02 < n_as); matrix_rows[i02 * n_ids * ids->ne[1] + matrix_row_counts[i02]] = (struct mmid_row_mapping) { id, iid1 }; matrix_row_counts[i02] += 1; @@ -4945,60 +3325,292 @@ int op_matmul_id(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - bool hmx_eligible = false; -#ifdef HTP_HAS_HMX - if (octx->ctx->hmx_enabled && src1_nrows > 1) { - uint32_t wtype = src0->type; - if (ne01 % 32 == 0 && - (wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32 || wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || wtype == HTP_TYPE_MXFP4)) { - if ((wtype == HTP_TYPE_F16 || wtype == HTP_TYPE_F32) && ne00 % 32 == 0) { - hmx_eligible = true; - } else if (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32 && ne00 % 256 == 0) { - hmx_eligible = true; - } - } + if (kparams->n_hmx) { + return hmx_mm_op_matmul_id(octx, mmctx, matrix_row_counts, matrix_rows, mapping_buf, must_free_mapping); } -#endif - mmctx->hmx_eligible = hmx_eligible; + return hvx_mm_op_matmul_id(octx, mmctx, src0_row_size_padded, src1_nrows, matmul_id_job_func, mapping_buf, must_free_mapping); +} - if (hmx_eligible) { - for (uint32_t cur_a = 0; cur_a < n_as; ++cur_a) { - const int32_t cne1 = matrix_row_counts[cur_a]; - if (cne1 == 0) continue; +int op_matmul_qkv(struct htp_ops_context * octx) { + const struct htp_tensor * restrict src0 = octx->src[0]; // Wk + const struct htp_tensor * restrict src1 = octx->src[1]; // x + const struct htp_tensor * restrict src2 = octx->src[2]; // Wv + const struct htp_tensor * restrict src3 = octx->src[3]; // Wq + const struct htp_tensor * restrict dst_k = octx->dsts[0]; + const struct htp_tensor * restrict dst_v = octx->dsts[1]; + const struct htp_tensor * restrict dst_q = octx->dsts[2]; - int ret = hmx_matmul_id_2d_f32(octx->ctx, (float*) dst->data, (float*) src1->data, - (const uint8_t *) src0->data + cur_a * nb02, - cne1, ne00, ne01, - ne11, - nb11, nb12, - nb1, nb2, - (int) src0->nb[1], (int) src0->type, - matrix_rows, cur_a, n_ids * ids->ne[1]); - if (ret != 0) { - FARF(ERROR, "HMX matmul failed for expert %u, error %d\n", cur_a, ret); - if (must_free_mapping) free(mapping_buf); - return HTP_STATUS_NO_SUPPORT; - } + bool is_repacked = (src0->type == HTP_TYPE_Q4_0 || src0->type == HTP_TYPE_Q4_1 || + src0->type == HTP_TYPE_Q8_0 || src0->type == HTP_TYPE_IQ4_NL || + src0->type == HTP_TYPE_MXFP4); + + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + if (is_repacked) { + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); + } else { + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + } + + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (src1->ne[0] + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; + + worker_callback_t quant_job_func; + uint32_t n_quant_jobs = 1; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_flat : quantize_f32_q8_0_flat; + } else if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * ith) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } - // HMX has overwritten VTCM, so force dynamic quantization cache to clear - octx->src1_spad.src = NULL; + size_t src1_row_size; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(src1->ne[0]) : htp_mm_q8_0_flat_row_size(src1->ne[0]); + } else { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(src1->ne[0]) : htp_mm_q8_0_tiled_row_size(src1->ne[0]); + } - if (must_free_mapping) free(mapping_buf); + // Set up scratchpads using precomputed sizes from the host + size_t src0_sz = kparams->vtcm_src0_size; + size_t src1_sz = kparams->vtcm_src1_size; + size_t src2_sz = kparams->vtcm_src2_size; + size_t src3_sz = kparams->vtcm_src3_size; + size_t vtcm_size = kparams->vtcm_size; + + size_t src0_sz_per_thread = src0_sz / octx->n_threads; + size_t src1_sz_per_thread = src1_sz; + size_t src2_sz_per_thread = src2_sz / octx->n_threads; + size_t src3_sz_per_thread = src3_sz / octx->n_threads; + + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-qkv: current VTCM reservation %zu is too small, needed %zu\n", + octx->ctx->vtcm_size, vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_src2 = vtcm_seq_alloc(&vtcm_ptr, src2_sz); + mmctx->vtcm_src3 = vtcm_seq_alloc(&vtcm_ptr, src3_sz); + + octx->src1_spad.src = NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + octx->src3_spad.src = NULL; + + mmctx->vtcm_src0_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src2_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src3_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; + + mmctx->vtcm_src0_size_per_thread = src0_sz_per_thread; + mmctx->vtcm_src1_size_per_thread = src1_sz_per_thread; + mmctx->vtcm_src2_size_per_thread = src2_sz_per_thread; + mmctx->vtcm_src3_size_per_thread = src3_sz_per_thread; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) return HTP_STATUS_OK; - } - if (octx->src1_spad.src != src1) { - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); - octx->src1_spad.src = src1; - } + // Run quantization once + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + // Run fused matmul const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); + worker_callback_t matmul_job_func; + if (is_repacked) { + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_qkv_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_qkv_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_qkv_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_qkv_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_qkv_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_qkv_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } + } else { + matmul_job_func = hvx_mm_qkv_2d; + } + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); + + return HTP_STATUS_OK; +} + +int op_matmul_ffn(struct htp_ops_context * octx) { + const struct htp_tensor * restrict src0 = octx->src[0]; // Wgate + const struct htp_tensor * restrict src1 = octx->src[1]; // y + const struct htp_tensor * restrict src2 = octx->src[2]; // Wup + const struct htp_tensor * restrict dst_gate = octx->dsts[0]; + const struct htp_tensor * restrict dst_up = octx->dsts[1]; + + bool is_repacked = (src0->type == HTP_TYPE_Q4_0 || src0->type == HTP_TYPE_Q4_1 || + src0->type == HTP_TYPE_Q8_0 || src0->type == HTP_TYPE_IQ4_NL || + src0->type == HTP_TYPE_MXFP4); + + struct htp_mm_context mmctx_struct = {0}; + struct htp_mm_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + + const struct htp_mm_kernel_params * kparams = (const struct htp_mm_kernel_params *) octx->kernel_params; + + const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + const uint32_t src1_nrows = src1->ne[1] * src1->ne[2] * src1->ne[3]; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + if (is_repacked) { + mmctx->src0_nrows_per_thread = hex_round_up(mmctx->src0_nrows_per_thread, 32); + } else { + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + } + + const size_t src0_row_size = src0->nb[1]; + const size_t src0_row_size_padded = hex_round_up(src0_row_size, 128); + + if (hvx_mm_init_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; + } + + const uint32_t qk = QK_Q8_0_TILED; + const uint32_t nb = (src1->ne[0] + qk - 1) / qk; + const uint32_t total_nb = src1_nrows * nb; + + worker_callback_t quant_job_func; + uint32_t n_quant_jobs = 1; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_flat : quantize_f32_q8_0_flat; + } else if (src1_nrows < octx->n_threads) { + n_quant_jobs = MIN(total_nb, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled_block : quantize_f32_q8_0_tiled_block; + for (uint32_t ith = 0; ith < n_quant_jobs; ++ith) { + uint32_t ib_first = (total_nb * (ith + 0)) / n_quant_jobs; + uint32_t ib_last = (total_nb * (ith + 1)) / n_quant_jobs; + mmctx->quant_ib_first[ith] = ib_first; + mmctx->quant_ib_last[ith] = ib_last; + mmctx->quant_r[ith] = ib_first / nb; + mmctx->quant_c[ith] = ib_first % nb; + } + } else { + n_quant_jobs = MIN(src1_nrows, octx->n_threads); + quant_job_func = (src0->type == HTP_TYPE_Q4_1) ? quantize_f32_q8_1_tiled : quantize_f32_q8_0_tiled; + } + + size_t src1_row_size; + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(src1->ne[0]) : htp_mm_q8_0_flat_row_size(src1->ne[0]); + } else { + src1_row_size = (src0->type == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(src1->ne[0]) : htp_mm_q8_0_tiled_row_size(src1->ne[0]); + } + + // Set up scratchpads using precomputed sizes from the host + size_t src0_sz = kparams->vtcm_src0_size; + size_t src1_sz = kparams->vtcm_src1_size; + size_t src2_sz = kparams->vtcm_src2_size; + size_t vtcm_size = kparams->vtcm_size; + + size_t src0_sz_per_thread = src0_sz / octx->n_threads; + size_t src1_sz_per_thread = src1_sz; + size_t src2_sz_per_thread = src2_sz / octx->n_threads; + + if (octx->ctx->vtcm_size < vtcm_size) { + FARF(ERROR, "matmul-ffn: current VTCM reservation %zu is too small, needed %zu\n", octx->ctx->vtcm_size, vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + uint8_t * vtcm_ptr = (uint8_t *) octx->ctx->vtcm_base; + mmctx->vtcm_src1 = vtcm_seq_alloc(&vtcm_ptr, src1_sz); + mmctx->vtcm_src0 = vtcm_seq_alloc(&vtcm_ptr, src0_sz); + mmctx->vtcm_src2 = vtcm_seq_alloc(&vtcm_ptr, src2_sz); + + octx->src1_spad.src = NULL; + octx->src0_spad.src = NULL; + octx->src2_spad.src = NULL; + + mmctx->vtcm_src0_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src2_stride = is_repacked ? 0 : src0_row_size_padded; + mmctx->vtcm_src1_stride = src1_row_size; + + mmctx->vtcm_src0_size_per_thread = src0_sz_per_thread; + mmctx->vtcm_src1_size_per_thread = src1_sz_per_thread; + mmctx->vtcm_src2_size_per_thread = src2_sz_per_thread; + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) + return HTP_STATUS_OK; + + // Run quantization once + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); + + // Run fused matmul + const uint32_t n_matmul_jobs = octx->n_threads; + worker_callback_t matmul_job_func; + if (is_repacked) { + if (kparams->kernel_type == HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT) { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_0_flat; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_1_flat; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q8_0_flat; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_ffn_2d_repacked_iq4nl_flat; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_ffn_2d_repacked_mxfp4_flat; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } else { + switch (src0->type) { + case HTP_TYPE_Q4_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_0; break; + case HTP_TYPE_Q4_1: matmul_job_func = hvx_mm_ffn_2d_repacked_q4_1; break; + case HTP_TYPE_Q8_0: matmul_job_func = hvx_mm_ffn_2d_repacked_q8_0; break; + case HTP_TYPE_IQ4_NL: matmul_job_func = hvx_mm_ffn_2d_repacked_iq4nl; break; + case HTP_TYPE_MXFP4: matmul_job_func = hvx_mm_ffn_2d_repacked_mxfp4; break; + default: return HTP_STATUS_NO_SUPPORT; + } + } + } else { + matmul_job_func = hvx_mm_ffn_2d; + } + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); - if (must_free_mapping) free(mapping_buf); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.h b/ggml/src/ggml-hexagon/htp/matmul-ops.h new file mode 100644 index 0000000000..a94d5430da --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.h @@ -0,0 +1,508 @@ +#ifndef HTP_MATMUL_OPS_H +#define HTP_MATMUL_OPS_H + +#include +#include +#include "htp-ops.h" +#include "hex-fastdiv.h" +#include "hex-common.h" + +#ifdef __cplusplus +extern "C" { +#endif + +// --- HMX Tile Constraints --- +#define HTP_MM_HMX_TILE_N_COLS 32 +#define HTP_MM_HMX_TILE_N_ROWS 32 +#define HTP_MM_HMX_TILE_SIZE (32 * 32 * sizeof(__fp16)) // 2048 bytes +#define HTP_MM_HMX_TILE_N_ELMS 1024 +#define HTP_MM_HMX_MIN_NROWS 4 + +// --- Weight Repacked Tile Sizes --- +#define HTP_MM_WEIGHT_TILE_SIZE_Q4_0 576 +#define HTP_MM_WEIGHT_TILE_SIZE_Q4_1 640 +#define HTP_MM_WEIGHT_TILE_SIZE_Q8_0 1088 +#define HTP_MM_WEIGHT_TILE_SIZE_IQ4_NL 576 +#define HTP_MM_WEIGHT_TILE_SIZE_MXFP4 544 + +// --- Weight Repacked Aligned Tile Sizes --- +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0 640 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1 640 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0 1152 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_IQ4_NL 640 +#define HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4 640 + +// --- Activation Tiled Block Sizes (including padding) --- +#define HTP_MM_ACT_TILE_SIZE_Q8_0 1152 +#define HTP_MM_ACT_TILE_SIZE_Q8_1 1280 + +#define HTP_MM_MAX_PREFETCH 16 + +// --- Solver Cost Model Penalty Weights (HMX-specific) --- +#define HTP_MM_HMX_COST_W_DEQUANT 3 // cost penalty for quantized weight loading/dequantization +#define HTP_MM_HMX_COST_A_CONVERT 2 // cost penalty for activation loading/conversion + +// --- DMA Activation Transfer Configuration --- +#define HTP_MM_DMA_ACT_ROWS_PER_STEP 2 +#define HTP_MM_DMA_ACT_MULTIPLIER 4 + +enum htp_mm_kernel_type { + HTP_MM_KERNEL_UNSUPPORTED = 0, + + // HMX paths + HTP_MM_KERNEL_HMX_2D, + HTP_MM_KERNEL_HMX_F16_BATCHED, + + // HVX floating-point paths + HTP_MM_KERNEL_HVX_F16_F16_VTCM, + HTP_MM_KERNEL_HVX_F16_F16_DDR, + HTP_MM_KERNEL_HVX_F16_F32_DDR, + + HTP_MM_KERNEL_HVX_F32_F32_VTCM, + HTP_MM_KERNEL_HVX_F32_F32_DDR, + HTP_MM_KERNEL_HVX_F32_F16_DDR, + + // HVX quantized paths + HTP_MM_KERNEL_HVX_QUANT_ROW, // standard row-wise parallel quantization + HTP_MM_KERNEL_HVX_QUANT_BLOCK, // parallel block-wise quantization + HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT, // row-wise fallback flat quantization +}; + +// Op-specific struct for precomputed matmul params +struct htp_mm_kernel_params { + int32_t kernel_type; // enum htp_mm_kernel_type + int32_t pipeline; // 1 = pipelined execution, 0 = standard + int32_t m_chunk; // Row chunk size (M chunk) + int32_t n_chunk; // Col chunk size (N chunk) + int32_t n_threads; // Number of threads to spawn + int32_t n_act_threads; // Number of threads for activation preparation + int32_t n_hmx; // 1 = use HMX, 0 = use HVX + int32_t n_prefetch; // Prefetch lookahead buffers/rows in VTCM + int32_t tile_size; // Weight tile size + int32_t aligned_tile_size; // Aligned weight tile size (padded to 128) + int32_t src1_row_size; // Row size for quantized activation + int32_t vtcm_size; // Total required scratchpad size in VTCM + int32_t vtcm_src0_size; // src0 scratchpad size in VTCM + int32_t vtcm_src1_size; // src1 scratchpad size in VTCM + int32_t vtcm_src2_size; // src2 scratchpad size in VTCM (fused only) + int32_t vtcm_src3_size; // src3 scratchpad size in VTCM (fused only) + int32_t vtcm_dst_size; // dst scratchpad size in VTCM + + // Precomputed division values + struct fastdiv_values div_ne12_ne1; + struct fastdiv_values div_ne1; + struct fastdiv_values div_r2; + struct fastdiv_values div_r3; + struct fastdiv_values div_ne11; +}; + +#if defined(__cplusplus) +static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob"); +#else +_Static_assert(sizeof(struct htp_mm_kernel_params) <= 128, "htp_matmul_kernel_params is too large for kernel_params blob"); +#endif + +struct mmid_row_mapping { + uint32_t i1; + uint32_t i2; +}; + +// Search for optimal (mc, nc) chunk sizes within VTCM budget. +static inline int htp_mm_hmx_compute_chunks(size_t vtcm_total, + size_t overhead, + size_t per_n_cost, + size_t per_m_cost, + size_t per_mn_cost, + size_t m, + size_t n, + size_t m_block_cost, + size_t n_block_cost, + size_t * m_chunk_out, + size_t * n_chunk_out, + size_t * total_out) { + if (m == 0 || n == 0) return -1; + if (vtcm_total <= overhead) return -1; + if (per_n_cost == 0 || per_m_cost == 0 || per_mn_cost == 0) return -1; + + const size_t usable = vtcm_total - overhead; + + size_t best_cost = SIZE_MAX; + size_t best_mn = 0; + size_t best_m = 0, best_n = 0; + + const size_t n_max = hex_align_down((size_t)n, HTP_MM_HMX_TILE_N_COLS); + for (size_t nc = n_max; nc >= HTP_MM_HMX_TILE_N_COLS; nc -= HTP_MM_HMX_TILE_N_COLS) { + size_t n_fixed = 0, ncmn = 0, mc_denom = 0; + if (hex_mul_overflow(nc, per_n_cost, &n_fixed)) continue; + if (n_fixed >= usable) goto next_nc; + + if (hex_mul_overflow(nc, per_mn_cost, &ncmn)) goto next_nc; + if (hex_add_overflow(per_m_cost, ncmn, &mc_denom) || mc_denom == 0) goto next_nc; + + { + size_t remain = usable - n_fixed; + size_t mc = remain / mc_denom; + mc = hex_align_down(mc, HTP_MM_HMX_TILE_N_ROWS); + mc = hex_smin(mc, m); + + if (mc == 0) { + goto next_nc; + } + + size_t mblocks = ((size_t) m + mc - 1) / mc; + size_t nblocks = ((size_t) n + nc - 1) / nc; + size_t cost = mblocks * m_block_cost + nblocks * n_block_cost; + size_t mn = mc * nc; + if (cost < best_cost || (cost == best_cost && mn > best_mn)) { + best_cost = cost; + best_mn = mn; + best_m = mc; + best_n = nc; + } + } + +next_nc: + if (nc == HTP_MM_HMX_TILE_N_COLS) break; // avoid size_t underflow + } + + if (best_m == 0 || best_n == 0) return -1; + + // Compute exact total (with overflow checks) + size_t t0 = 0, t1 = 0, t2 = 0, mn = 0, total = 0; + if (hex_mul_overflow(best_n, per_n_cost, &t0)) return -1; + if (hex_mul_overflow(best_m, per_m_cost, &t1)) return -1; + if (hex_mul_overflow(best_m, best_n, &mn)) return -1; + if (hex_mul_overflow(mn, per_mn_cost, &t2)) return -1; + if (hex_add_overflow(t0, t1, &total)) return -1; + if (hex_add_overflow(total, t2, &total)) return -1; + if (hex_add_overflow(total, overhead, &total)) return -1; + + *m_chunk_out = best_m; + *n_chunk_out = best_n; + *total_out = total; + return 0; +} + +// --- Tile Size Helpers --- +static inline uint32_t htp_mm_get_weight_tile_size(int weight_type) { + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return HTP_MM_WEIGHT_TILE_SIZE_Q4_0; + case HTP_TYPE_Q4_1: + return HTP_MM_WEIGHT_TILE_SIZE_Q4_1; + case HTP_TYPE_Q8_0: + return HTP_MM_WEIGHT_TILE_SIZE_Q8_0; + case HTP_TYPE_MXFP4: + return HTP_MM_WEIGHT_TILE_SIZE_MXFP4; + default: + return 0; + } +} + +static inline uint32_t htp_mm_get_weight_aligned_tile_size(int weight_type) { + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_0; + case HTP_TYPE_Q4_1: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q4_1; + case HTP_TYPE_Q8_0: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_Q8_0; + case HTP_TYPE_MXFP4: + return HTP_MM_WEIGHT_ALIGNED_TILE_SIZE_MXFP4; + default: + return 0; + } +} + +// --- Activation/Row Size Helpers --- +static inline size_t htp_mm_q8_0_tiled_row_size(uint32_t ne) { + const uint32_t ne_padded = ((ne + 127) / 128) * 128; + const uint32_t nb_32 = ne_padded / 32; + return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_0; +} + +static inline size_t htp_mm_q8_1_tiled_row_size(uint32_t ne) { + const uint32_t ne_padded = ((ne + 127) / 128) * 128; + const uint32_t nb_32 = ne_padded / 32; + return nb_32 * HTP_MM_ACT_TILE_SIZE_Q8_1; +} + +static inline size_t htp_mm_q8_0_flat_row_size(uint32_t ne) { + const uint32_t quants_size = hex_align_up(ne, 128); + const uint32_t num_scales = (ne + 31) / 32; + const uint32_t scales_size = hex_align_up(num_scales * 2, 128); + return quants_size + scales_size; +} + +static inline size_t htp_mm_q8_1_flat_row_size(uint32_t ne) { + const uint32_t quants_size = hex_align_up(ne, 128); + const uint32_t num_scales = (ne + 31) / 32; + const uint32_t scales_size = hex_align_up(num_scales * 4, 128); + return quants_size + scales_size; +} + +static inline size_t htp_mm_get_tiled_row_stride(int weight_type, uint32_t k) { + uint32_t nb = (k + QK_Q4_0_TILED - 1) / QK_Q4_0_TILED; + switch (weight_type) { + case HTP_TYPE_Q4_0: + case HTP_TYPE_IQ4_NL: + case HTP_TYPE_Q4_1: + case HTP_TYPE_Q8_0: + case HTP_TYPE_MXFP4: + return (size_t) nb * htp_mm_get_weight_tile_size(weight_type); + case HTP_TYPE_F16: + return (size_t) k * sizeof(__fp16); + case HTP_TYPE_F32: + return (size_t) k * sizeof(float); + default: + return 0; + } +} + +static inline size_t htp_mm_round_up(size_t n, size_t m) { + return ((n + m - 1) / m) * m; +} + +static inline bool htp_mm_hmx_pipeline(uint32_t m) { + return m > 32; +} + +static inline void htp_mm_hmx_get_2d_chunk_costs( + int wtype, uint32_t k, bool pipeline, uint32_t aligned_tile_size, + size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out +) { + const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32); + const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k); + const size_t vec_dot_size = k * sizeof(uint16_t); + const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const size_t qweight_row_stride = is_quant ? (size_t)(n_k_tiles * aligned_tile_size) / 32 : 0; + + *size_per_n_out = (pipeline ? 2 : 1) * (is_quant ? qweight_row_stride : row_stride) + + (pipeline ? 2 * vec_dot_size : vec_dot_size); + *size_per_m_out = vec_dot_size; + *size_per_mn_out = (pipeline ? 2 : 1) * sizeof(uint16_t); +} + +static inline void htp_mm_hmx_get_batched_chunk_costs( + uint32_t k, uint32_t group_size, + size_t * size_per_n_out, size_t * size_per_m_out, size_t * size_per_mn_out +) { + const size_t vec_dot_size = k * sizeof(uint16_t); + *size_per_n_out = 3 * vec_dot_size; + *size_per_m_out = group_size * vec_dot_size; + *size_per_mn_out = sizeof(uint16_t); +} + +static inline size_t htp_mm_hmx_get_2d_vtcm_size( + int wtype, uint32_t k, size_t mc, size_t nc, bool pipeline, uint32_t act_threads, uint32_t aligned_tile_size +) { + const uint32_t n_k_tiles = k / HTP_MM_HMX_TILE_N_COLS; + const bool is_quant = (wtype != HTP_TYPE_F16 && wtype != HTP_TYPE_F32); + const size_t row_stride = htp_mm_get_tiled_row_stride(wtype, k); + const size_t vec_dot_size = k * sizeof(uint16_t); + + const size_t act_f32_size = htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE); + size_t weight_area_size = is_quant + ? htp_mm_round_up((nc / 32) * n_k_tiles * aligned_tile_size, HTP_MM_HMX_TILE_SIZE) + : htp_mm_round_up(nc * row_stride, HTP_MM_HMX_TILE_SIZE); + if (pipeline) { + weight_area_size *= 2; + } + const size_t act_area_size = htp_mm_round_up(mc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = htp_mm_round_up(mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE); + + size_t scratch0_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + size_t scratch1_size = pipeline ? scratch0_size : 0; + size_t scratch2_size = pipeline ? output_area_size : 0; + + return weight_area_size + act_area_size + act_f32_size + output_area_size + + scratch0_size + scratch1_size + scratch2_size + 256; +} + +static inline size_t htp_mm_hmx_get_batched_vtcm_size( + int wtype, uint32_t k, size_t mc, size_t nc, uint32_t group_size, bool use_dma_activation, bool pipeline, uint32_t act_threads) { + (void)wtype; + (void)pipeline; + const size_t vec_dot_size = k * sizeof(uint16_t); + const size_t f32_scratch_size = use_dma_activation + ? htp_mm_round_up(act_threads * 4 * k * sizeof(float), HTP_MM_HMX_TILE_SIZE) : 0; + + const size_t act_head_stride = mc * k; + const size_t weight_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + const size_t act_area_size = htp_mm_round_up(group_size * act_head_stride * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE); + const size_t output_area_size = htp_mm_round_up(group_size * mc * nc * sizeof(uint16_t), HTP_MM_HMX_TILE_SIZE); + const size_t scratch_area_size = htp_mm_round_up(nc * vec_dot_size, HTP_MM_HMX_TILE_SIZE); + + return weight_area_size + act_area_size + output_area_size + + 2 * scratch_area_size + 256 + f32_scratch_size; +} + +static inline size_t htp_mm_hvx_get_vtcm_sizes( + int kernel_type, + int wtype, + uint32_t ne10, // k + uint32_t src1_nrows, // m_total (or act_nrows) + uint32_t n_threads, + size_t dst_row_size, + size_t src0_row_size, + size_t src1_row_size, + uint32_t n_prefetch, + size_t * vtcm_src0_size_out, + size_t * vtcm_src1_size_out, + size_t * vtcm_dst_size_out +) { + size_t vtcm_src0_size = 0; + size_t vtcm_src1_size = 0; + size_t vtcm_dst_size = 0; + + const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || + wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || + wtype == HTP_TYPE_MXFP4); + + const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128); + const size_t dst_nrows = (src1_nrows > 1) ? 0 : 1; + + switch (kernel_type) { + case HTP_MM_KERNEL_HVX_F16_F16_VTCM: { + size_t f16_src1_row_size = htp_mm_round_up(ne10 * 2, 128); + vtcm_src1_size = htp_mm_round_up(f16_src1_row_size * src1_nrows, 256); + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads; + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0; + break; + } + case HTP_MM_KERNEL_HVX_F16_F32_DDR: + case HTP_MM_KERNEL_HVX_F16_F16_DDR: + case HTP_MM_KERNEL_HVX_F32_F32_DDR: + case HTP_MM_KERNEL_HVX_F32_F16_DDR: { + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size, 256) * n_threads; + vtcm_src1_size = htp_mm_round_up(n_prefetch * src1_row_size, 256) * n_threads; + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0; + break; + } + case HTP_MM_KERNEL_HVX_F32_F32_VTCM: { + size_t f32_src1_row_size = htp_mm_round_up(ne10 * 4, 128); + vtcm_src1_size = htp_mm_round_up(f32_src1_row_size * src1_nrows, 256); + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256) * n_threads; + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) * n_threads : 0; + break; + } + case HTP_MM_KERNEL_HVX_QUANT_BLOCK: + case HTP_MM_KERNEL_HVX_QUANT_ROW: { + size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) : htp_mm_q8_0_tiled_row_size(ne10); + + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0; + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256); + vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256); + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, QK_Q8_0_TILED * sizeof(float)); + if (vtcm_src0_size < src1_row_size_padded) { + vtcm_src0_size = src1_row_size_padded; + } + + vtcm_src0_size = vtcm_src0_size * n_threads; + vtcm_dst_size = vtcm_dst_size * n_threads; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = ne10 / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + vtcm_src0_size = repacked_vtcm_size * n_threads; + } + break; + } + case HTP_MM_KERNEL_HVX_QUANT_ROW_FLAT: { + size_t q_src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_flat_row_size(ne10) : htp_mm_q8_0_flat_row_size(ne10); + + vtcm_dst_size = dst_nrows > 0 ? htp_mm_round_up(dst_row_size, 128) : 0; + vtcm_src0_size = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256); + vtcm_src1_size = htp_mm_round_up(q_src1_row_size * src1_nrows, 256); + + size_t src1_row_size_padded = htp_mm_round_up(q_src1_row_size, 256); + if (vtcm_src0_size < src1_row_size_padded) { + vtcm_src0_size = src1_row_size_padded; + } + + vtcm_src0_size = vtcm_src0_size * n_threads; + vtcm_dst_size = vtcm_dst_size * n_threads; + + if (is_repack) { + uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + uint32_t n_k_tiles = ne10 / 32; + uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + vtcm_src0_size = repacked_vtcm_size * n_threads; + } + break; + } + default: + break; + } + + *vtcm_src0_size_out = vtcm_src0_size; + *vtcm_src1_size_out = vtcm_src1_size; + *vtcm_dst_size_out = vtcm_dst_size; + + return vtcm_src0_size + vtcm_src1_size + vtcm_dst_size; +} + +static inline size_t htp_mm_hvx_id_get_vtcm_sizes( + int wtype, + uint32_t ne10, // k + uint32_t src1_nrows, + uint32_t n_threads, + size_t src0_row_size, // nb01 + uint32_t n_prefetch, + size_t * vtcm_src0_size_out, + size_t * vtcm_src1_size_out +) { + const bool is_repack = (wtype == HTP_TYPE_Q4_0 || wtype == HTP_TYPE_Q4_1 || + wtype == HTP_TYPE_Q8_0 || wtype == HTP_TYPE_IQ4_NL || + wtype == HTP_TYPE_MXFP4); + + const size_t src0_row_size_padded = htp_mm_round_up(src0_row_size, 128); + const size_t src1_row_size = (wtype == HTP_TYPE_Q4_1) ? htp_mm_q8_1_tiled_row_size(ne10) + : htp_mm_q8_0_tiled_row_size(ne10); + + size_t src0_sz_per_thread = htp_mm_round_up(n_prefetch * src0_row_size_padded, 256); + size_t src1_sz = htp_mm_round_up(src1_row_size * src1_nrows, 256); + + // src0 spad also holds temporary transposed src1 columns during dynamic quantization. + const size_t src1_row_size_padded = htp_mm_round_up(src1_row_size, QK_Q8_0_TILED * sizeof(float)); + if (src0_sz_per_thread < src1_row_size_padded) { + src0_sz_per_thread = src1_row_size_padded; + } + + if (is_repack) { + const uint32_t aligned_tile_size = htp_mm_get_weight_aligned_tile_size(wtype); + const uint32_t n_k_tiles = ne10 / 32; + const uint32_t tile_row_size = n_k_tiles * aligned_tile_size; + size_t repacked_vtcm_size = htp_mm_round_up(n_prefetch * tile_row_size, 256); + if (repacked_vtcm_size < src1_row_size_padded) { + repacked_vtcm_size = src1_row_size_padded; + } + src0_sz_per_thread = repacked_vtcm_size; + } + + const size_t vtcm_src0_size = src0_sz_per_thread * n_threads; + + *vtcm_src0_size_out = vtcm_src0_size; + *vtcm_src1_size_out = src1_sz; + + return vtcm_src0_size + src1_sz; +} + +#ifdef __cplusplus +} +#endif + +#endif // HTP_MATMUL_OPS_H diff --git a/ggml/src/ggml-hexagon/htp/ssm-conv.c b/ggml/src/ggml-hexagon/htp/ssm-conv.c index d574da2e2b..a48bc9ed86 100644 --- a/ggml/src/ggml-hexagon/htp/ssm-conv.c +++ b/ggml/src/ggml-hexagon/htp/ssm-conv.c @@ -183,24 +183,25 @@ static inline void hvx_transpose_32x32_f32(HVX_Vector m[32]) { // transposed into VTCM. // // VTCM layouts (per thread): -// src1_T : {d_inner_per_thread, d_conv} โ€” staged once per launch (small). -// src0_T : {d_inner_tile, ncs} โ€” staged per d_inner-tile. +// src1_T : {d_inner_stride, d_conv} - staged once per launch (small). +// src0_T : {d_inner_tile, ncs} - staged per d_inner-tile. // // d_inner_tile is chosen so that per-thread VTCM stays under the budget. // Each thread iterates ceil(d_inner_per_thread d_inner_tile) tiles serially. #define HTP_SSM_CONV_VTCM_BUDGET (1u << 20) // 1 MiB per thread -// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_per_thread, d_conv} (VTCM) +// Scalar transpose: src1 {d_conv, d_inner} (DDR) -> {d_inner_stride, d_conv} (VTCM) static inline void transpose_src1(const float * src1_data, uint32_t src1_stride_inner, uint32_t i1_off, uint32_t d_inner_per_thread, + uint32_t d_inner_stride, uint32_t d_conv, float * src1_T) { for (uint32_t i = 0; i < d_inner_per_thread; ++i) { const float * src_row = src1_data + (i1_off + i) * src1_stride_inner; for (uint32_t j = 0; j < d_conv; ++j) { - src1_T[j * d_inner_per_thread + i] = src_row[j]; + src1_T[j * d_inner_stride + i] = src_row[j]; } } } @@ -280,6 +281,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void } const uint32_t d_inner_per_thread = ir1 - ir0; + const uint32_t d_inner_stride = scctx->nrows_per_thread; const uint32_t d_inner_tile = scctx->d_inner_tile; const float * src0_data = (const float *) src0->data; @@ -290,8 +292,8 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void float * src0_T = (float *)(octx->src0_spad.data + ith * octx->src0_spad.size_per_thread); float * src1_T = (float *)(octx->src1_spad.data + ith * octx->src1_spad.size_per_thread); - // Stage src1 weights once into VTCM in {d_inner_per_thread, d_conv} layout. - transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_conv, src1_T); + // Stage src1 weights once into VTCM in {d_inner_stride, d_conv} layout. + transpose_src1(src1_data, src1_stride_inner, ir0, d_inner_per_thread, d_inner_stride, d_conv, src1_T); const uint32_t C_TILE = VLEN_FP32; @@ -314,7 +316,7 @@ static void ssm_conv_thread_f32_f32_hvx(unsigned int nth, unsigned int ith, void HVX_Vector acc = hvx_vec_splat_f32(0.0f); for (uint32_t j = 0; j < d_conv; ++j) { HVX_Vector x = *(const HVX_Vector *) (src0_T + (t + j) * d_inner_tile + cb); - HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_per_thread + tile_off + cb); + HVX_Vector w = *(const HVX_Vector *) (src1_T + j * d_inner_stride + tile_off + cb); acc = Q6_Vqf32_vadd_Vqf32Vqf32(acc, Q6_Vqf32_vmpy_VsfVsf(x, w)); } HVX_Vector res = Q6_Vsf_equals_Vqf32(acc); @@ -362,8 +364,7 @@ int op_ssm_conv_f32(struct htp_ops_context * octx) { use_hvx = 1; } - scctx.nrows_per_thread = (d_inner + n_threads - 1) / n_threads; - scctx.nrows_per_thread += (scctx.nrows_per_thread & 1); + scctx.nrows_per_thread = hex_round_up((d_inner + n_threads - 1) / n_threads, VLEN_FP32); const uint32_t d_inner_per_thread = scctx.nrows_per_thread; const uint32_t ncs = src0->ne[0]; diff --git a/ggml/src/ggml-hexagon/libggml-htp.inf b/ggml/src/ggml-hexagon/libggml-htp.inf index 39cefcdda3..874dde1b88 100644 --- a/ggml/src/ggml-hexagon/libggml-htp.inf +++ b/ggml/src/ggml-hexagon/libggml-htp.inf @@ -14,8 +14,6 @@ Drivers_Dir = 13 1 = %DiskId% [SourceDisksFiles] -libggml-htp-v68.so = 1 -libggml-htp-v69.so = 1 libggml-htp-v73.so = 1 libggml-htp-v75.so = 1 libggml-htp-v79.so = 1 @@ -28,8 +26,6 @@ ExcludeFromSelect = * CopyFiles=Drivers_Dir [Drivers_Dir] -libggml-htp-v68.so,,,0x10 ;COPYFLG_NO_OVERWRITE -libggml-htp-v69.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v73.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v75.so,,,0x10 ;COPYFLG_NO_OVERWRITE libggml-htp-v79.so,,,0x10 ;COPYFLG_NO_OVERWRITE diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 5ad8d76fa5..fb330e0625 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -10152,14 +10152,8 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int ne00 = src0 ? src0->ne[0] : 0; - const int ne01 = src0 ? src0->ne[1] : 0; - const int ne02 = src0 ? src0->ne[2] : 0; - const int ne03 = src0 ? src0->ne[3] : 0; - - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const cl_ulong nb03 = src0 ? src0->nb[3] : 0; + GGML_TENSOR_LOCALS(int, ne0, src0, ne); + GGML_TENSOR_LOCALS(cl_ulong, nb0, src0, nb); const int nth = MIN(64, ne00); @@ -10173,11 +10167,12 @@ static void ggml_cl_norm(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(float), &eps)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float)*nth, NULL)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(float), &eps)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(float)*nth, NULL)); size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)nth, 1, 1}; diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl index 9703b693e5..f5c6fb3e84 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl @@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32( regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; - dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB); + dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB); } // reduction in local memory, assumes #wave=4 diff --git a/ggml/src/ggml-opencl/kernels/norm.cl b/ggml/src/ggml-opencl/kernels/norm.cl index 170f822787..a5ccac2413 100644 --- a/ggml/src/ggml-opencl/kernels/norm.cl +++ b/ggml/src/ggml-opencl/kernels/norm.cl @@ -24,6 +24,7 @@ kernel void kernel_norm( int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, @@ -43,7 +44,8 @@ kernel void kernel_norm( // parallel sum sum[get_local_id(0)] = 0.0f; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - sum[get_local_id(0)] += x[i00]; + // this kernel handles float, nb00/4 translates byte offset to element offset + sum[get_local_id(0)] += x[i00*nb00/4]; } // reduce barrier(CLK_LOCAL_MEM_FENCE); @@ -60,7 +62,8 @@ kernel void kernel_norm( global float * y = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; sum[get_local_id(0)] = 0.0f; for (int i00 = get_local_id(0); i00 < ne00; i00 += get_local_size(0)) { - y[i00] = x[i00] - mean; + // this kernel handles float, nb00/4 translates byte offset to element offset + y[i00] = x[i00*nb00/4] - mean; sum[get_local_id(0)] += y[i00] * y[i00]; } diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index ad2e6ca35e..306eeddc0c 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -293,6 +293,11 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t (sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); + } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_BF16) { + op()((const sycl::ext::oneapi::bfloat16 *) src0->data, (const float *) src1->data, + (sycl::ext::oneapi::bfloat16 *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, + ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), + ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); #endif } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), diff --git a/ggml/src/ggml-sycl/conv3d.cpp b/ggml/src/ggml-sycl/conv3d.cpp index 2fa29f9305..3796562553 100644 --- a/ggml/src/ggml-sycl/conv3d.cpp +++ b/ggml/src/ggml-sycl/conv3d.cpp @@ -103,8 +103,8 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { // allocate packed arrays: A_packed (k x m), B_packed (k x n) ggml_sycl_pool_alloc A_packed_alloc(ctx.pool()); ggml_sycl_pool_alloc B_packed_alloc(ctx.pool()); - A_packed_alloc.alloc((size_t) knl_n_total * patch_total * sizeof(float)); - B_packed_alloc.alloc((size_t) knl_n_total * oc * sizeof(float)); + A_packed_alloc.alloc((size_t) knl_n_total * patch_total); + B_packed_alloc.alloc((size_t) knl_n_total * oc); float * A_packed = A_packed_alloc.get(); float * B_packed = B_packed_alloc.get(); @@ -115,10 +115,16 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { // Combined kernel: im2col -> pack A, and pack B simultaneously const char * src1_base = (const char *) src1->data; + const char * src0_base = (const char *) src0->data; const int64_t src1_nb0 = src1->nb[0]; const int64_t src1_nb1 = src1->nb[1]; const int64_t src1_nb2 = src1->nb[2]; const int64_t src1_nb3 = src1->nb[3]; + const int64_t src1_w = src1->ne[0]; + const int64_t src1_h = src1->ne[1]; + const int64_t src1_d = src1->ne[2]; + + const bool src0_is_f32 = (src0->type == GGML_TYPE_F32); // Compute correct strides for src0 as (knl_n_total, oc) matrix const int64_t src0_packed_nb0 = kernel_type_size; @@ -165,7 +171,7 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const int64_t sz = dst_z * s2 + kz * d2 - p2; float val = 0.0f; - if (sx >= 0 && sx < src1->ne[0] && sy >= 0 && sy < src1->ne[1] && sz >= 0 && sz < src1->ne[2]) { + if (sx >= 0 && sx < src1_w && sy >= 0 && sy < src1_h && sz >= 0 && sz < src1_d) { const int64_t channel_idx = batch_idx * c + ic; const char * ptr = src1_base + sx * src1_nb0 + sy * src1_nb1 + sz * src1_nb2 + channel_idx * src1_nb3; val = *(const float *) ptr; @@ -184,9 +190,9 @@ void ggml_sycl_op_conv_3d(ggml_backend_sycl_context & ctx, ggml_tensor * dst) { const int64_t row = t % k; const int64_t col = t / k; - const char * src_ptr = (const char *) src0->data + row * src0_packed_nb0 + col * src0_packed_nb1; + const char * src_ptr = src0_base + row * src0_packed_nb0 + col * src0_packed_nb1; float v; - if (src0->type == GGML_TYPE_F32) { + if (src0_is_f32) { v = *(const float *) src_ptr; } else { v = sycl::vec(*(const sycl::half *) src_ptr).convert()[0]; diff --git a/ggml/src/ggml-sycl/element_wise.cpp b/ggml/src/ggml-sycl/element_wise.cpp index aca68e58ee..0c82ceb969 100644 --- a/ggml/src/ggml-sycl/element_wise.cpp +++ b/ggml/src/ggml-sycl/element_wise.cpp @@ -43,14 +43,44 @@ static __dpct_inline__ T op_sgn(T x) { return x > static_cast(0.f) ? static_cast(1.f) : ((x < static_cast(0.f) ? static_cast(-1.f) : static_cast(0.f))); } + template static __dpct_inline__ T op_abs(T x) { - return sycl::fabs(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fabs(x); // or experimental namespace if needed + } else { + return sycl::fabs(x); + } +} + +template +static __dpct_inline__ T op_expm1(T x) { + if constexpr (std::is_same_v) { + return static_cast( + sycl::expm1(static_cast(x)) + ); + } else { + return sycl::expm1(x); + } } template static __dpct_inline__ T op_elu(T x) { - return (x > static_cast(0.f)) ? x : sycl::expm1(x); + return (x > static_cast(0.f)) ? x : op_expm1(x); +} + +template +static __dpct_inline__ T op_tanh(T x) { + if constexpr (std::is_same_v) { + constexpr int ver = __INTEL_LLVM_COMPILER; +#if defined(__INTEL_LLVM_COMPILER) && (__INTEL_LLVM_COMPILER >= 20260000) + return sycl::ext::oneapi::experimental::tanh(x); +#else + return static_cast(sycl::tanh(static_cast(x))); +#endif + } else { + return sycl::tanh(x); + } } template @@ -59,74 +89,106 @@ static __dpct_inline__ T op_gelu(T x) { const T SQRT_2_OVER_PI = static_cast(0.79788456080286535587989211986876f); return static_cast(0.5f) * x * (static_cast(1.0f) + - sycl::tanh(SQRT_2_OVER_PI * x * (static_cast(1.0f) + GELU_COEF_A * x * x))); + op_tanh(SQRT_2_OVER_PI * x * (static_cast(1.0f) + GELU_COEF_A * x * x))); +} + +template +static __dpct_inline__ T op_exp(T x) { + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::exp(x); + } else { + return sycl::exp(x); + } } template static __dpct_inline__ T op_silu(T x) { - return x / (static_cast(1.0f) + sycl::native::exp(-x)); + return x / (static_cast(1.0f) + op_exp(-x)); } template -static __dpct_inline__ T op_gelu_quick(T x) { - const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); - return x * (static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(GELU_QUICK_COEF_LOCAL * x))); +static __dpct_inline__ T op_erf(T x) { + if constexpr (std::is_same_v) { + return static_cast( + sycl::erf(static_cast(x)) + ); + } else { + return sycl::erf(x); + } } template static __dpct_inline__ T op_gelu_erf(T x) { const T SQRT_2_INV = static_cast(0.70710678118654752440084436210484f); - return static_cast(0.5f) * x * (static_cast(1.0f) + sycl::erf(x * SQRT_2_INV)); + return static_cast(0.5f) * x * (static_cast(1.0f) + op_erf(x * SQRT_2_INV)); } template -static __dpct_inline__ T op_tanh(T x) { - return sycl::tanh(x); +static __dpct_inline__ T op_gelu_quick(T x) { + const T GELU_QUICK_COEF_LOCAL = static_cast(-1.702f); + return x * (static_cast(1.0f) / (static_cast(1.0f) + op_exp(GELU_QUICK_COEF_LOCAL * x))); } template static __dpct_inline__ T op_relu(T x) { - return sycl::fmax(x, static_cast(0)); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmax(x, static_cast(0)); + } else { + return sycl::fmax(x, static_cast(0)); + } } template static __dpct_inline__ T op_sigmoid(T x) { - return static_cast(1.0f) / (static_cast(1.0f) + sycl::native::exp(-x)); + return static_cast(1.0f) / (static_cast(1.0f) + op_exp(-x)); } template static __dpct_inline__ T op_sqrt(T x) { - return sycl::sqrt(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::sqrt(x); + } else { + return sycl::sqrt(x); + } } template static __dpct_inline__ T op_sin(T x) { - return sycl::sin(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::sin(x); + } else { + return sycl::sin(x); + } } template static __dpct_inline__ T op_cos(T x) { - return sycl::cos(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::cos(x); + } else { + return sycl::cos(x); + } } template static __dpct_inline__ T op_hardsigmoid(T x) { - return sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmin( + static_cast(1.0f), sycl::ext::oneapi::experimental::fmax( + static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } else { + return sycl::fmin(static_cast(1.0f), + sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } } template static __dpct_inline__ T op_hardswish(T x) { - return x * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); -} - -template -static __dpct_inline__ T op_exp(T x) { - return sycl::exp(x); -} - -template -static __dpct_inline__ T op_expm1(T x) { - return sycl::expm1(x); + if constexpr (std::is_same_v) { + return x * sycl::ext::oneapi::experimental::fmin(static_cast(1.0f), sycl::ext::oneapi::experimental::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } else { + return x * sycl::fmin(static_cast(1.0f), sycl::fmax(static_cast(0.0f), (x + static_cast(3.0f)) / static_cast(6.0f))); + } } template @@ -134,13 +196,17 @@ static __dpct_inline__ T op_log(T x) { if (x <= static_cast(0)) { return neg_infinity(); } - return sycl::log(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::log(x); + } else { + return sycl::log(x); + } } template static __dpct_inline__ T op_softplus(T x) { const float xf = (float) x; - const float ax = sycl::fabs(xf); + const float ax = op_abs(xf); const float m = sycl::fmax(xf, 0.0f); const float y = m + sycl::log1p(sycl::exp(-ax)); return (T) y; @@ -159,8 +225,14 @@ static __dpct_inline__ T op_step(T x) { template static __dpct_inline__ T op_leaky_relu(T x, float negative_slope) { T neg_slope_T = static_cast(negative_slope); - return sycl::fmax(x, static_cast(0)) + + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::fmax(x, static_cast(0)) + + sycl::ext::oneapi::experimental::fmin(x, static_cast(0.0f)) * neg_slope_T; + + } else { + return sycl::fmax(x, static_cast(0)) + sycl::fmin(x, static_cast(0.0f)) * neg_slope_T; + } } template @@ -175,22 +247,40 @@ static __dpct_inline__ T op_clamp(T x, float min_val, float max_val) { template static __dpct_inline__ T op_floor(T x) { - return sycl::floor(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::floor(x); + } else { + return sycl::floor(x); + } } template static __dpct_inline__ T op_ceil(T x) { - return sycl::ceil(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::ceil(x); + } else { + return sycl::ceil(x); + } } template static __dpct_inline__ T op_round(T x) { - return sycl::round(x); + if constexpr (std::is_same_v) { + return static_cast( + sycl::round(static_cast(x)) + ); + } else { + return sycl::round(x); + } } template static __dpct_inline__ T op_trunc(T x) { - return sycl::trunc(x); + if constexpr (std::is_same_v) { + return sycl::ext::oneapi::experimental::trunc(x); + } else { + return sycl::trunc(x); + } } template @@ -339,7 +429,7 @@ static void acc_f32_sycl(const float *x, const float *y, float *dst, const int num_blocks = (n_elements + SYCL_ACC_BLOCK_SIZE - 1) / SYCL_ACC_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_ACC_BLOCK_SIZE)), - [=](sycl::nd_item<3> /*item_ct1*/) { + [=](sycl::nd_item<3> /*item_ct1*/) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { acc_f32(x, y, dst, n_elements, ne10, ne11, ne12, ne13, s1, s2, s3, offset); }); } @@ -354,8 +444,8 @@ static void arange_kernel(T * dst, const int k, T start, T step, template static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, ggml_tensor * dst, KernelInvoker kernel_invoker, Args&&... args) { - GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16); - GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32 || dst->src[0]->type == GGML_TYPE_F16 || dst->src[0]->type == GGML_TYPE_BF16); + GGML_ASSERT(dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_BF16); GGML_ASSERT(dst->src[0]->type == dst->type); dpct::queue_ptr main_stream = ctx.stream(); @@ -367,6 +457,14 @@ static inline void dispatch_ggml_sycl_op_unary(ggml_backend_sycl_context & ctx, kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); break; } +#ifdef GGML_SYCL_HAS_BF16 + case GGML_TYPE_BF16: + { + auto data_pts = cast_data(dst); + kernel_invoker(data_pts.src, data_pts.dst, (int)ggml_nelements(dst->src[0]), main_stream, std::forward(args)...); + break; + } +#endif case GGML_TYPE_F32: { auto data_pts = cast_data(dst); @@ -480,7 +578,7 @@ static inline void ggml_sycl_op_unary( stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256), sycl::range<1>(256)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_generic_kernel( src, dst_ptr, k_elements, ne0, ne1, ne2, ne3, @@ -508,7 +606,7 @@ static inline void ggml_sycl_op_arange(ggml_backend_sycl_context & ctx, ggml_ten stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE), sycl::range<1>(SYCL_ARANGE_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { arange_kernel(dst_ptr, k, start, step, item_ct1); }); } @@ -602,7 +700,7 @@ static inline void ggml_sycl_op_log(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_EXP_BLOCK_SIZE), sycl::range<1>(SYCL_EXP_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_log_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -640,7 +738,7 @@ static inline void ggml_sycl_op_sqrt(ggml_backend_sycl_context & ctx, ggml_tenso stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQRT_BLOCK_SIZE), sycl::range<1>(SYCL_SQRT_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sqrt_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -653,7 +751,7 @@ static inline void ggml_sycl_op_sin(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sin_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -666,7 +764,7 @@ static inline void ggml_sycl_op_cos(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SIN_BLOCK_SIZE), sycl::range<1>(SYCL_SIN_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_cos_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -681,7 +779,7 @@ static inline void ggml_sycl_op_leaky_relu(ggml_backend_sycl_context & ctx, ggml stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_RELU_BLOCK_SIZE), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_leaky_relu_kernel(src, dst_ptr, k_elements, slope, item_ct1); }); }, negative_slope); @@ -694,7 +792,7 @@ static inline void ggml_sycl_op_sqr(ggml_backend_sycl_context & ctx, ggml_tensor stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_SQR_BLOCK_SIZE), sycl::range<1>(SYCL_SQR_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { unary_op_sqr_kernel(src, dst_ptr, k_elements, item_ct1); }); }); @@ -711,7 +809,7 @@ static inline void ggml_sycl_op_clamp(ggml_backend_sycl_context & ctx, ggml_tens stream->parallel_for( sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE), sycl::range<1>(SYCL_CLAMP_BLOCK_SIZE)), - [=](sycl::nd_item<1> item_ct1) { + [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { clamp(src, dst_ptr, min_arg, max_arg, k_elements, item_ct1); }); }, min_val, max_val); @@ -774,7 +872,8 @@ static inline void ggml_sycl_op_geglu(ggml_backend_sycl_context & ctx, ggml_tens [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -785,7 +884,8 @@ static inline void ggml_sycl_op_reglu(ggml_backend_sycl_context & ctx, ggml_tens [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_RELU_BLOCK_SIZE); // Using RELU block size for reglu main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_RELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_reglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -796,7 +896,8 @@ static inline void ggml_sycl_op_swiglu(ggml_backend_sycl_context & ctx, ggml_ten [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div((uint32_t)k, SYCL_SILU_BLOCK_SIZE); // Using SILU block size for swiglu main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), + sycl::range<1>(SYCL_SILU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_swiglu(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -811,7 +912,6 @@ __dpct_inline__ float ggml_sycl_op_swiglu_oai_single(float x, float g, float alp return out_glu; } - template static void swiglu_oai_kernel(const T * x, const T * g, T * dst, const int64_t k, const int64_t n, const int64_t o0, const int64_t o1, @@ -845,7 +945,7 @@ static void swiglu_oai_sycl(const T * x, const int64_t num_blocks = (k + SYCL_GLU_BLOCK_SIZE - 1) / SYCL_GLU_BLOCK_SIZE; stream->parallel_for(sycl::nd_range<3>(sycl::range<3>(1, 1, num_blocks) * sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE), sycl::range<3>(1, 1, SYCL_GLU_BLOCK_SIZE)), - [=](sycl::nd_item<3> item_ct1) { + [=](sycl::nd_item<3> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { swiglu_oai_kernel(x, g, dst, k, n, o0, o1, alpha, limit, item_ct1); }); } @@ -899,7 +999,8 @@ static inline void ggml_sycl_op_geglu_erf(ggml_backend_sycl_context & ctx, ggml_ [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu_erf(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); @@ -910,7 +1011,8 @@ static inline void ggml_sycl_op_geglu_quick(ggml_backend_sycl_context & ctx, ggm [](const auto* x_ptr, const auto* g_ptr, auto* dst_ptr, uint64_t k, uint64_t n, uint64_t o0, uint64_t o1, queue_ptr main_stream) { const uint32_t num_blocks = ceil_div(k, SYCL_GELU_BLOCK_SIZE); main_stream->parallel_for( - sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) { + sycl::nd_range<1>((num_blocks * sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), + sycl::range<1>(SYCL_GELU_BLOCK_SIZE)), [=](sycl::nd_item<1> item_ct1) [[sycl::reqd_sub_group_size(WARP_SIZE)]] { gated_op_fused_geglu_quick(x_ptr, g_ptr, dst_ptr, k, n, o0, o1, item_ct1); }); }); diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp index d8b83d0e23..41449db665 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp @@ -5859,6 +5859,250 @@ static ggml_backend_dev_t ggml_backend_sycl_reg_get_device(ggml_backend_reg_t re return ctx->devices[index]; } +// ========================================================================== +// Tensor parallelism (--split-mode tensor) for the SYCL backend. +// +// The meta-backend invokes these three entry points via get_proc_address: +// * ggml_backend_sycl_comm_init - one-time per-graph setup +// * ggml_backend_sycl_comm_allreduce_tensor - per-allreduce step +// * ggml_backend_sycl_comm_free - tear-down +// +// For N=2 (dual-GPU), this is a degenerate ring allreduce with dual paths +// chosen by tensor size: +// +// * Small (nelem < 32K): FP32 direct memcpy + per-device ADD +// kernel. The kernel depends_on() its corresponding memcpy event +// so it doesn't read partial data. Both devices run in parallel. +// +// * Large (nelem >= 32K): BF16-compressed. Each device compresses +// its FP32 partial to BF16 locally, cross-device memcpys +// to the peer (half the PCI bandwidth), where it is decompressed +// and added into the local FP32 partial. 6 SYCL submissions per +// allreduce (2 compress + 2 memcpy + 2 decompress-add) vs the +// 4 for the small path, but the bandwidth saving > 6 GB/s PCIe x 2 +// dominates for larger tensors. +// +// Storage: A persistent uint8_t buffer per device, sized to +// 4 * nelem bytes. Both paths reinterpret the same bytes (small path +// as nelem floats; large path as outbox + inbox = 2*nelem uint16_t +// each, using the full 4*nelem byte budget either way). Single +// alloc+free per device keeps the SYCL pool's strict-LIFO invariant +// trivial. +// +// For non-(N=2 FP32 contiguous) cases, comm_init or comm_allreduce_tensor +// returns null/false, causing the meta-backend to use its generic +// butterfly all-reduce fallback. +// ========================================================================== + +struct ggml_backend_sycl_comm_context { + std::vector backends; + // ONE persistent per-device byte buffer, 4*nelem bytes. Both the + // FP32 small-tensor path and the BF16 large-tensor path share it + // by reinterpreting. + std::unique_ptr> buf0; + std::unique_ptr> buf1; + int64_t buf_nelem = 0; +}; + +void * ggml_backend_sycl_comm_init(ggml_backend_t * backends, size_t n_backends) try { + for (size_t i = 0; i < n_backends; ++i) { + if (!ggml_backend_is_sycl(backends[i])) { + return nullptr; + } + } + + // Initial version: N=2 only. For N!=2, returning null makes the + // meta-backend skip this backend-specific allreduce entirely. + if (n_backends != 2) { + return nullptr; + } + + auto * ctx = new ggml_backend_sycl_comm_context; + ctx->backends.assign(backends, backends + n_backends); + auto * sctx0 = (ggml_backend_sycl_context *) backends[0]->context; + auto * sctx1 = (ggml_backend_sycl_context *) backends[1]->context; + ctx->buf0 = std::make_unique>(sctx0->pool()); + ctx->buf1 = std::make_unique>(sctx1->pool()); + return ctx; +} +catch (const sycl::exception &) { return nullptr; } +catch (...) { return nullptr; } + +void ggml_backend_sycl_comm_free(void * comm_ctx_v) { + auto * comm_ctx = static_cast(comm_ctx_v); + if (comm_ctx == nullptr) { + return; + } + + // Sync both per-device queues so the pool_alloc destructors don't + // return memory still in use by the last kernel. + if (comm_ctx->backends.size() == 2) { + auto * sctx0 = (ggml_backend_sycl_context *) comm_ctx->backends[0]->context; + auto * sctx1 = (ggml_backend_sycl_context *) comm_ctx->backends[1]->context; + try { + sctx0->stream()->wait(); + sctx1->stream()->wait(); + } catch (...) { /* best effort during shutdown */ } + } + + delete comm_ctx; +} + +bool ggml_backend_sycl_comm_allreduce_tensor(void * comm_ctx_v, struct ggml_tensor ** tensors) try { + if (comm_ctx_v == nullptr) { + return false; + } + + auto * comm_ctx = static_cast(comm_ctx_v); + const size_t n_backends = comm_ctx->backends.size(); + + // Fast path: N=2, F32/F16, contiguous, matching shapes. + if (n_backends != 2) { + return false; + } + // Accept F32 or F16 inputs natively (types must match). F16 takes the + // direct 2-byte memcpy + add path below; other types return false so the + // meta-backend uses its generic all-reduce. + if (tensors[0]->type != tensors[1]->type) { + return false; + } + if (tensors[0]->type != GGML_TYPE_F32 && tensors[0]->type != GGML_TYPE_F16) { + return false; + } + if (!ggml_is_contiguous(tensors[0]) || !ggml_is_contiguous(tensors[1])) { + return false; + } + if (ggml_nelements(tensors[0]) != ggml_nelements(tensors[1])) { + return false; + } + + const int64_t nelem = ggml_nelements(tensors[0]); + const size_t nbytes = ggml_nbytes(tensors[0]); + if (nelem == 0) { + return true; + } + + auto * ctx0 = (ggml_backend_sycl_context *) comm_ctx->backends[0]->context; + auto * ctx1 = (ggml_backend_sycl_context *) comm_ctx->backends[1]->context; + queue_ptr q0 = ctx0->stream(); + queue_ptr q1 = ctx1->stream(); + + // Grow per-device byte buffers if needed (4 * nelem bytes each). + if (comm_ctx->buf_nelem < nelem) { + comm_ctx->buf0->realloc(nelem * 4); + comm_ctx->buf1->realloc(nelem * 4); + comm_ctx->buf_nelem = nelem; + } + uint8_t * buf0 = comm_ctx->buf0->get(); + uint8_t * buf1 = comm_ctx->buf1->get(); + + // F16 native path: direct 2-byte cross-device copy + add, skipping the + // F32 round-trip the meta-backend fallback would force. Cross-device copies + // go through dev2dev_memcpy because the two devices are in separate SYCL + // contexts (a raw peer-USM q->memcpy would be a silent no-op). + if (tensors[0]->type == GGML_TYPE_F16) { + sycl::half * f16_out0 = (sycl::half *) tensors[0]->data; + sycl::half * f16_out1 = (sycl::half *) tensors[1]->data; + sycl::half * f16_tmp0 = (sycl::half *) buf0; + sycl::half * f16_tmp1 = (sycl::half *) buf1; + + q0->wait(); + q1->wait(); + dev2dev_memcpy(ctx0->device, *q0, ctx1->device, *q1, f16_tmp0, tensors[1]->data, nbytes); + dev2dev_memcpy(ctx1->device, *q1, ctx0->device, *q0, f16_tmp1, tensors[0]->data, nbytes); + + q0->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + f16_out0[i] = (sycl::half) ((float) f16_out0[i] + (float) f16_tmp0[i]); + }); + }); + q1->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + f16_out1[i] = (sycl::half) ((float) f16_out1[i] + (float) f16_tmp1[i]); + }); + }); + return true; + } + + float * out0 = (float *) tensors[0]->data; + float * out1 = (float *) tensors[1]->data; + + // BF16 threshold: above this, the PCIe savings from halving the + // cross-device bytes outweigh the 2 extra compress kernels. + // Below: stay on the FP32 fast path. Threshold mirrors the CUDA + // NCCL allreduce pattern for n_backends=2. + static constexpr int64_t BF16_THRESHOLD = 32768; + + if (nelem < BF16_THRESHOLD) { + // FP32 small path: 4 SYCL submissions per allreduce. + float * tmp0 = (float *) buf0; + float * tmp1 = (float *) buf1; + + // COMM-D2D-FIX: the two devices are in SEPARATE SYCL contexts, so a raw + // q->memcpy of a peer USM pointer is a silent no-op. Route cross-device + // copies through dev2dev_memcpy (L0 direct copy / host staging). It is + // synchronous, so wait for the local partials to be produced first. + q0->wait(); + q1->wait(); + dev2dev_memcpy(ctx0->device, *q0, ctx1->device, *q1, tmp0, tensors[1]->data, nbytes); + dev2dev_memcpy(ctx1->device, *q1, ctx0->device, *q0, tmp1, tensors[0]->data, nbytes); + + q0->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + out0[i] += tmp0[i]; + }); + }); + q1->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + out1[i] += tmp1[i]; + }); + }); + return true; + } + + // BF16 large path: 6 SYCL submissions per allreduce, but the + // cross-device memcpy is HALF the bytes. Pure bit-shift + // conversion (no rounding) โ€” matches ggml's truncating fp32->bf16. + uint16_t * outbox0 = (uint16_t *) buf0; + uint16_t * inbox0 = outbox0 + nelem; + uint16_t * outbox1 = (uint16_t *) buf1; + uint16_t * inbox1 = outbox1 + nelem; + + // Phase A: compress each device's local partial in parallel. + sycl::event c0 = q0->parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + outbox0[i] = (uint16_t) (sycl::bit_cast(out0[i]) >> 16); + }); + + sycl::event c1 = q1->parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + outbox1[i] = (uint16_t) (sycl::bit_cast(out1[i]) >> 16); + }); + + // Phase B: COMM-D2D-FIX-BF16 cross-device copy of compressed bytes via + // dev2dev_memcpy (separate SYCL contexts; sync copy after compress). + const size_t bf16_bytes = nelem * sizeof(uint16_t); + c0.wait(); + c1.wait(); + dev2dev_memcpy(ctx0->device, *q0, ctx1->device, *q1, inbox0, outbox1, bf16_bytes); + dev2dev_memcpy(ctx1->device, *q1, ctx0->device, *q0, inbox1, outbox0, bf16_bytes); + + // Phase C: decompress + add into local FP32 partial. + q0->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + out0[i] += sycl::bit_cast(((uint32_t) inbox0[i]) << 16); + }); + }); + + q1->submit([&](sycl::handler & h) { + h.parallel_for(sycl::range<1>(nelem), [=](sycl::id<1> i) { + out1[i] += sycl::bit_cast(((uint32_t) inbox1[i]) << 16); + }); + }); + + return true; +} +catch (const sycl::exception &) { return false; } +catch (...) { return false; } + static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, const char *name) { GGML_UNUSED(reg); @@ -5866,6 +6110,17 @@ static void *ggml_backend_sycl_reg_get_proc_address(ggml_backend_reg_t reg, cons return (void *)ggml_backend_sycl_split_buffer_type; } + // Tensor parallelism (--split-mode tensor) entry points. + if (strcmp(name, "ggml_backend_comm_init") == 0) { + return (void *)ggml_backend_sycl_comm_init; + } + if (strcmp(name, "ggml_backend_comm_free") == 0) { + return (void *)ggml_backend_sycl_comm_free; + } + if (strcmp(name, "ggml_backend_comm_allreduce_tensor") == 0) { + return (void *)ggml_backend_sycl_comm_allreduce_tensor; + } + // SYCL doesn't support registering host memory, left here for reference // "ggml_backend_register_host_buffer" // "ggml_backend_unregister_host_buffer" diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 2d9e85794a..5aeb6e97b1 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -108,6 +108,9 @@ if (Vulkan_FOUND) if (GGML_VULKAN_CHECK_RESULTS) add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + # the result-checking path computes a CPU reference graph via + # ggml_graph_compute_with_ctx(), which is defined in ggml-cpu + target_link_libraries(ggml-vulkan PRIVATE ggml-cpu) endif() if (GGML_VULKAN_DEBUG) @@ -129,6 +132,8 @@ if (Vulkan_FOUND) if (GGML_VULKAN_RUN_TESTS) add_compile_definitions(GGML_VULKAN_RUN_TESTS) + # the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu) + target_link_libraries(ggml-vulkan PRIVATE ggml-cpu) endif() # Set up toolchain for host compilation whether cross-compiling or not diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 9a36b45de8..5fbebc6d75 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -493,6 +493,20 @@ struct vk_conv2d_pipeline_state { } }; +struct vk_conv3d_pipeline_state { + vk_conv3d_pipeline_state(uint32_t s0, uint32_t s1, uint32_t s2, uint32_t p0, uint32_t p1, uint32_t p2, + uint32_t d0, uint32_t d1, uint32_t d2, uint32_t KW, uint32_t KH, uint32_t KD, uint32_t aligned) + : s0(s0), s1(s1), s2(s2), p0(p0), p1(p1), p2(p2), d0(d0), d1(d1), d2(d2), KW(KW), KH(KH), KD(KD), aligned(aligned) {} + + uint32_t s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD; + uint32_t aligned; + + bool operator<(const vk_conv3d_pipeline_state &b) const { + return std::tie(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned) < + std::tie(b.s0, b.s1, b.s2, b.p0, b.p1, b.p2, b.d0, b.d1, b.d2, b.KW, b.KH, b.KD, b.aligned); + } +}; + struct vk_solve_tri_pipeline_state { vk_solve_tri_pipeline_state(uint32_t N, uint32_t K) : N(N), K(K) {} @@ -685,6 +699,7 @@ struct vk_device_struct { bool add_rms_fusion; uint32_t partials_binding_alignment; + uint32_t max_nodes_per_submit; bool shader_64b_indexing; @@ -777,6 +792,7 @@ struct vk_device_struct { vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_get_rows_back_f32; vk_pipeline pipeline_acc_f32; vk_pipeline pipeline_set_f32; @@ -801,14 +817,10 @@ struct vk_device_struct { vk_pipeline pipeline_concat_i8, pipeline_concat_i16, pipeline_concat_i32, pipeline_concat_i64; vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bicubic_f32, pipeline_upscale_bilinear_antialias_f32; vk_pipeline pipeline_scale_f32; - vk_pipeline pipeline_sqr_f32; - vk_pipeline pipeline_sqrt_f32; - vk_pipeline pipeline_sin_f32; - vk_pipeline pipeline_cos_f32; vk_pipeline pipeline_log[2]; vk_pipeline pipeline_tri[2]; vk_pipeline pipeline_diag[2]; - vk_pipeline pipeline_clamp_f32; + vk_pipeline pipeline_clamp[2]; vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_repeat_i32, pipeline_repeat_back_f32; @@ -840,6 +852,10 @@ struct vk_device_struct { vk_pipeline pipeline_gelu_quick[2]; vk_pipeline pipeline_silu[2]; vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_sqr[2]; + vk_pipeline pipeline_sqrt[2]; + vk_pipeline pipeline_sin[2]; + vk_pipeline pipeline_cos[2]; vk_pipeline pipeline_xielu[2]; vk_pipeline pipeline_neg[2]; vk_pipeline pipeline_tanh[2]; @@ -871,7 +887,7 @@ struct vk_device_struct { vk_pipeline pipeline_geglu_erf[2]; vk_pipeline pipeline_geglu_quick[2]; - vk_pipeline pipeline_leaky_relu_f32; + vk_pipeline pipeline_leaky_relu[2]; vk_pipeline pipeline_silu_back_f32; vk_pipeline pipeline_diag_mask_inf_f32; vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; @@ -924,6 +940,8 @@ struct vk_device_struct { std::map pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; std::map pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; std::map pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv3d_f32[CONV_SHAPE_COUNT]; + std::map pipeline_conv3d_f16_f32[CONV_SHAPE_COUNT]; vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32; vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; @@ -1669,6 +1687,41 @@ template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); } +struct vk_op_conv3d_push_constants { + uint32_t OC; + uint32_t IC; + uint32_t N; + + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t OW; + uint32_t OH; + uint32_t OD; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t OWOHODmp; uint32_t OWOHODL; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv3d_push_constants &p) { + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); + init_fastdiv_values(p.OW*p.OH*p.OD, p.OWOHODmp, p.OWOHODL); +} + struct vk_op_conv2d_dw_push_constants { uint32_t ne; uint32_t batches; @@ -4074,19 +4127,35 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } #endif + auto const &ggml_vk_mul_mm_spec = [](std::vector spec, bool aligned) { + spec.push_back(aligned ? 1u : 0u); + return spec; + }; + const int mul_mat_id_param_count = 5; #if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (device->coopmat2) { + auto const &ggml_vk_mul_mm_cm2_spec = [](std::vector spec, bool aligned, bool mul_mat_id) { + if (mul_mat_id && spec.size() > 5) { + spec.insert(spec.begin() + 5, aligned ? 1u : 0u); + } else { + spec.push_back(aligned ? 1u : 0u); + } + if (mul_mat_id && spec.size() == 6) { + spec.push_back(32); + } + return spec; + }; // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, true); \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, false, PARAMCOUNT == mul_mat_id_param_count), 1, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(l_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), l_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(m_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), m_align, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_cm2_spec(s_ ## WARPTILE, true, PARAMCOUNT == mul_mat_id_param_count), s_align, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ @@ -4161,17 +4230,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, true); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, true); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, true); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, true); \ // Create 2 variants, {f16,f32} accumulator #define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ @@ -4284,32 +4353,32 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { // Selects dot2 SPIR-V variant at runtime when device->dot2_f16 is true #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _len : NAMELC ## _aligned ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2_aligned ## F16ACC ## _data : NAMELC ## _aligned ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _len : NAMELC ## F16ACC ## _len), (device->dot2_f16 ? NAMELC ## _dot2 ## F16ACC ## _data : NAMELC ## F16ACC ## _data), "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ // bf16 scalar path promotes to f32, no dot2 variant #define CREATE_MM_NODOT2(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l_int[TYPE]) { \ @@ -4474,17 +4543,17 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { // Create 6 variants, {s,m,l}x{unaligned,aligned} #define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, false), 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _l[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, ggml_vk_mul_mm_spec(l_ ## WARPTILE, true), l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _m[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, ggml_vk_mul_mm_spec(m_ ## WARPTILE, true), m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ if (device->mul_mat ## ID ## _s[TYPE]) \ - ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, ggml_vk_mul_mm_spec(s_ ## WARPTILE, true), s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ #define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ if (device->mul_mat ## ID ## _l_int[TYPE]) \ @@ -4879,6 +4948,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_NVFP4], "get_rows_nvfp4_f32", get_rows_nvfp4_f32_len, get_rows_nvfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_back_f32, "get_rows_back_f32", get_rows_back_f32_len, get_rows_back_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {256, 1, 1}, {}, 1, true); ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, sizeof(vk_op_flash_attn_split_k_reduce_push_constants), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); @@ -4903,7 +4973,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_nc_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); @@ -5023,11 +5093,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -5037,8 +5102,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_diag[0], "diag_f32", diag_f32_len, diag_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag[1], "diag_f16", diag_f16_len, diag_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -5058,6 +5121,12 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_UNARY(gelu_quick) CREATE_UNARY(silu) CREATE_UNARY(relu) + CREATE_UNARY(sqr) + CREATE_UNARY(sqrt) + CREATE_UNARY(sin) + CREATE_UNARY(cos) + CREATE_UNARY(clamp) + CREATE_UNARY(leaky_relu) CREATE_UNARY(xielu) CREATE_UNARY(neg) CREATE_UNARY(tanh) @@ -5097,7 +5166,6 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { CREATE_GLU(geglu_quick) #undef CREATE_GLU - ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); @@ -5314,7 +5382,7 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); - // conv2d, conv_transpose_2d + // conv2d, conv_transpose_2d, conv3d for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { // smaller WG for the small-tile fallback gives more concurrent WGs per SM uint32_t conv2d_WG_SIZE = (s == CONV_SHAPE_64x32) ? 128 : 256; @@ -5377,8 +5445,8 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { return (conv2d_BS.K * (conv2d_BS.CRS + pad) + conv2d_BS.CRS * (conv2d_BS.NPQ + pad) + csh_elems) * elem_size; }; - // coopmat1 needs to store the output through shared memory, so check up front - // whether it'll fit and disable it before applying coopmat1 parameters. + // 2D, transpose-2D, and 3D conv use the same KxCRS @ CRSxNPQ shmem + // layout. cm1 needs Csh for output, so check before applying cm1 params. if (conv2d_use_cm1 && device->properties.limits.maxComputeSharedMemorySize < shmem_req(conv2d_cm1_shmem_pad, true, true)) { conv2d_use_cm1 = false; } @@ -5470,6 +5538,53 @@ static void ggml_vk_load_shaders(vk_device& device, vk_pipeline requested) { } #undef CREATE_CONV #undef CREATE_CONVS + + std::vector conv3d_spec_constants = { conv2d_WG_SIZE, conv2d_BS.K, conv2d_BS.CRS, conv2d_BS.NPQ, conv2d_TS_K, conv2d_SHMEM_PAD }; +#define CREATE_CONV3D(type_suffix, spv_suffix) \ + for (auto &c : device->pipeline_conv3d##type_suffix[s]) { \ + const vk_conv3d_pipeline_state &state = c.first; \ + std::vector spec_constants_cpy = conv3d_spec_constants; \ + spec_constants_cpy.push_back(state.s0); \ + spec_constants_cpy.push_back(state.s1); \ + spec_constants_cpy.push_back(state.s2); \ + spec_constants_cpy.push_back(state.p0); \ + spec_constants_cpy.push_back(state.p1); \ + spec_constants_cpy.push_back(state.p2); \ + spec_constants_cpy.push_back(state.d0); \ + spec_constants_cpy.push_back(state.d1); \ + spec_constants_cpy.push_back(state.d2); \ + spec_constants_cpy.push_back(state.KW); \ + spec_constants_cpy.push_back(state.KH); \ + spec_constants_cpy.push_back(state.KD); \ + spec_constants_cpy.push_back(state.aligned); \ + spec_constants_cpy.push_back(conv2d_csh_store); \ + spec_constants_cpy.push_back(conv2d_WM); \ + spec_constants_cpy.push_back(conv2d_WN); \ + ggml_vk_create_pipeline( \ + device, c.second, "conv3d" #type_suffix, \ + conv3d##type_suffix##spv_suffix##_len, conv3d##type_suffix##spv_suffix##_data, "main", 3, \ + sizeof(vk_op_conv3d_push_constants), wg_denoms, spec_constants_cpy, 1, true, conv2d_required_subgroup_size != 0, conv2d_required_subgroup_size); \ + } +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + CREATE_CONV3D(_f32, _cm2) + CREATE_CONV3D(_f16_f32, _cm2) + } else +#endif +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (conv2d_use_cm1) { + CREATE_CONV3D(_f32, _cm1) + CREATE_CONV3D(_f16_f32, _cm1) + } else +#endif + if (conv2d_UNROLL) { + CREATE_CONV3D(_f32, _unroll) + CREATE_CONV3D(_f16_f32, _unroll) + } else { + CREATE_CONV3D(_f32, ) + CREATE_CONV3D(_f16_f32, ) + } +#undef CREATE_CONV3D } ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); @@ -5764,6 +5879,14 @@ static vk_device ggml_vk_get_device(size_t idx) { device->subgroup_vote = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eVote); + // Submit at least every 100 nodes, in case there are workloads without as much matmul. + device->max_nodes_per_submit = 100; + const char* GGML_VK_MAX_NODES_PER_SUBMIT = getenv("GGML_VK_MAX_NODES_PER_SUBMIT"); + if (GGML_VK_MAX_NODES_PER_SUBMIT != nullptr) { + uint32_t max_nodes_per_submit = std::stoul(GGML_VK_MAX_NODES_PER_SUBMIT); + device->max_nodes_per_submit = std::max(max_nodes_per_submit, 1u); + } + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; @@ -10294,6 +10417,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_get_rows_f32[src0->type]; } return nullptr; + case GGML_OP_GET_ROWS_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_get_rows_back_f32; + } + return nullptr; case GGML_OP_ACC: if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { return ctx->device->pipeline_acc_f32; @@ -10400,23 +10528,27 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_SQR: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sqr_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_sqr[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_SQRT: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sqrt_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_sqrt[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_SIN: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_sin_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_sin[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_COS: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_cos_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_cos[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_LOG: @@ -10438,8 +10570,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_CLAMP: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_clamp_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_clamp[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_PAD: @@ -10807,8 +10940,9 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } return nullptr; case GGML_OP_LEAKY_RELU: - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_leaky_relu_f32; + if (src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) { + return ctx->device->pipeline_leaky_relu[dst->type == GGML_TYPE_F16]; } return nullptr; case GGML_OP_CONV_2D: @@ -10885,6 +11019,61 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const } } return nullptr; + case GGML_OP_CONV_3D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + const uint32_t OC = (uint32_t)ggml_get_op_params_i32(dst, 11); + const uint32_t IC = (uint32_t)ggml_get_op_params_i32(dst, 9); + const uint32_t N = (uint32_t)ggml_get_op_params_i32(dst, 10); + const uint32_t NPQ = N * dst->ne[2] * dst->ne[1] * dst->ne[0]; + const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, OC, NPQ); + + const uint32_t KW = (uint32_t)src0->ne[0]; + const uint32_t KH = (uint32_t)src0->ne[1]; + const uint32_t KD = (uint32_t)src0->ne[2]; + const uint32_t s0 = (uint32_t)ggml_get_op_params_i32(dst, 0); + const uint32_t s1 = (uint32_t)ggml_get_op_params_i32(dst, 1); + const uint32_t s2 = (uint32_t)ggml_get_op_params_i32(dst, 2); + const uint32_t p0 = (uint32_t)ggml_get_op_params_i32(dst, 3); + const uint32_t p1 = (uint32_t)ggml_get_op_params_i32(dst, 4); + const uint32_t p2 = (uint32_t)ggml_get_op_params_i32(dst, 5); + const uint32_t d0 = (uint32_t)ggml_get_op_params_i32(dst, 6); + const uint32_t d1 = (uint32_t)ggml_get_op_params_i32(dst, 7); + const uint32_t d2 = (uint32_t)ggml_get_op_params_i32(dst, 8); + + const uint32_t CRS = IC * KW * KH * KD; + const uint32_t BS_K = vk_conv_block_sizes[shape].K; + const uint32_t BS_CRS = vk_conv_block_sizes[shape].CRS; + const uint32_t BS_NPQ = vk_conv_block_sizes[shape].NPQ; + const uint32_t aligned = ((OC % BS_K == 0) && + (CRS % BS_CRS == 0) && + (NPQ % BS_NPQ == 0)) ? 1u : 0u; + + vk_conv3d_pipeline_state conv3d_pipeline_state(s0, s1, s2, p0, p1, p2, d0, d1, d2, KW, KH, KD, aligned); + + std::map *pipelines = nullptr; + if (src0->type == GGML_TYPE_F32) { + pipelines = &ctx->device->pipeline_conv3d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + pipelines = &ctx->device->pipeline_conv3d_f16_f32[shape]; + } else { + return nullptr; + } + + vk_pipeline pipeline = nullptr; + + { + std::lock_guard guard(ctx->device->compile_mutex); + auto it = pipelines->find(conv3d_pipeline_state); + if (it != pipelines->end()) { + pipeline = it->second; + } else { + (*pipelines)[conv3d_pipeline_state] = pipeline = std::make_shared(); + } + } + + return pipeline; + } + return nullptr; case GGML_OP_ADD1: if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { return ctx->device->pipeline_add1_f16_f16; @@ -11135,6 +11324,10 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); break; + case GGML_OP_GET_ROWS_BACK: + elements = { (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], 1 }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + break; case GGML_OP_ARGSORT: GGML_ASSERT(0); break; @@ -11220,6 +11413,21 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co GGML_ABORT("invalid push constant type for CONV_2D"); } break; + case GGML_OP_CONV_3D: + if constexpr (std::is_same_v) { + const uint32_t NPQ = pc.N * pc.OD * pc.OH * pc.OW; + const vk_conv_shapes shape = ggml_vk_conv_select_shape(ctx, pc.OC, NPQ); + const uint32_t NPQ_blocks = CEIL_DIV(NPQ, vk_conv_block_sizes[shape].NPQ); + + elements = { pc.OC, NPQ_blocks, 1 }; + if (elements[1] > 512) { + elements[2] = CEIL_DIV(elements[1], 512); + elements[1] = 512; + } + } else { + GGML_ABORT("invalid push constant type for CONV_3D"); + } + break; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_DIV: @@ -11236,6 +11444,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co case GGML_OP_TRI: case GGML_OP_DIAG: case GGML_OP_CLAMP: + case GGML_OP_LEAKY_RELU: case GGML_OP_PAD: case GGML_OP_ROLL: case GGML_OP_REPEAT: @@ -11380,6 +11589,21 @@ static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, }); } +static void ggml_vk_get_rows_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_GET_ROWS_BACK, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }); +} + static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const uint32_t src0_type_size = ggml_type_size(src0->type); const uint32_t src1_type_size = ggml_type_size(src1->type); @@ -12087,8 +12311,10 @@ static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { float * op_params = (float *)dst->op_params; + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_NORM, std::move(p)); } static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -13118,6 +13344,51 @@ static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, std::move(p)); } +static void ggml_vk_conv_3d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv3d_push_constants p{}; + p.IC = static_cast(ggml_get_op_params_i32(dst, 9)); + p.N = static_cast(ggml_get_op_params_i32(dst, 10)); + p.OC = static_cast(ggml_get_op_params_i32(dst, 11)); + GGML_ASSERT(src0->ne[3] == (int64_t)p.IC * p.OC); + GGML_ASSERT(src1->ne[3] == (int64_t)p.IC * p.N); + GGML_ASSERT(dst->ne[3] == (int64_t)p.OC * p.N); + + p.IW = static_cast(ne10); + p.IH = static_cast(ne11); + p.ID = static_cast(ne12); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + p.OD = static_cast(ne2); + + // the shader clamps src addresses to p.IC * p.N * p.IW * p.IH * p.ID - 1 in uint32, so the + // total input element count must fit in a uint32. + GGML_ASSERT((uint64_t)p.IC * p.N * p.IW * p.IH * p.ID <= 0xFFFFFFFFull); + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_CONV_3D, std::move(p)); +} + static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { vk_op_conv2d_dw_push_constants p{}; p.ne = ggml_nelements(dst); @@ -13144,7 +13415,10 @@ static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { const float * op_params = (const float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f, 0.0f, 0.0f }); + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, std::move(p)); } #ifdef GGML_VULKAN_RUN_TESTS @@ -14247,6 +14521,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_GET_ROWS: ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node); + break; + case GGML_OP_GET_ROWS_BACK: + ggml_vk_get_rows_back(ctx, compute_ctx, src0, src1, node); + break; case GGML_OP_ADD: if (ctx->num_additional_fused_ops) { @@ -14515,6 +14793,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr case GGML_OP_CONV_TRANSPOSE_2D: ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node); + break; + case GGML_OP_CONV_3D: + ggml_vk_conv_3d(ctx, compute_ctx, src0, src1, node); + break; case GGML_OP_CONV_2D_DW: ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node); @@ -15900,8 +16182,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB // (and scaled down based on model size, so smaller models submit earlier). - // Also submit at least every 100 nodes, in case there are workloads without as much matmul. - int nodes_per_submit = 100; int submitted_nodes = 0; int submit_count = 0; uint64_t mul_mat_bytes = 0; @@ -16127,7 +16407,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; - bool submit = (submitted_nodes >= nodes_per_submit) || + bool submit = ((uint32_t)submitted_nodes >= ctx->device->max_nodes_per_submit) || (mul_mat_bytes_per_submit != 0 && mul_mat_bytes >= mul_mat_bytes_per_submit) || (i + ctx->num_additional_fused_ops >= last_node) || (almost_ready && !ctx->almost_ready_fence_pending); @@ -16964,6 +17244,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } } + case GGML_OP_GET_ROWS_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SET_ROWS: { switch (op->type) { @@ -17060,12 +17342,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_TRANSPOSE: case GGML_OP_RMS_NORM: return true; - case GGML_OP_NORM: case GGML_OP_GROUP_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_NORM: case GGML_OP_L2_NORM: - return ggml_is_contiguous_rows(op->src[0]) && - op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -17084,8 +17365,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_LEAKY_RELU: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->type == op->src[0]->type; case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; @@ -17285,6 +17567,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm ggml_is_contiguous(op->src[1]) && ggml_is_contiguous(op)); } + case GGML_OP_CONV_3D: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op); default: return false; } @@ -18128,6 +18417,20 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * const int32_t d0 = tensor->op_params[4]; const int32_t d1 = tensor->op_params[5]; tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); + } else if (tensor->op == GGML_OP_CONV_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s2 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p2 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d2 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + const int32_t N = tensor->op_params[10]; + const int32_t OC = tensor->op_params[11]; + tensor_clone = ggml_conv_3d_direct(ggml_ctx, src_clone[0], src_clone[1], s0, s1, s2, p0, p1, p2, d0, d1, d2, IC, N, OC); } else if (tensor->op == GGML_OP_CONV_2D_DW) { const int32_t s0 = tensor->op_params[0]; const int32_t s1 = tensor->op_params[1]; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp deleted file mode 100644 index 653431895e..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? p.param2 : val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp new file mode 100644 index 0000000000..a9712eb3ac --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/conv3d_mm.comp @@ -0,0 +1,431 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#ifdef COOPMAT2 +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#include "types.glsl" + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout(binding = 0) readonly buffer A { + A_TYPE knl_data[]; +}; // src0 - kernel: [KW, KH, KD, IC*OC] + +layout(binding = 1) readonly buffer B { + B_TYPE src_data[]; +}; // src1 - input: [IW, IH, ID, IC*N] -- channel_first format + +layout(binding = 2) writeonly buffer D { + D_TYPE dst_data[]; +}; // dst - result: [OW, OH, OD, OC*N] + +layout(push_constant) uniform parameter { + // I/O channels, batch size + uint32_t OC; + uint32_t IC; + uint32_t N; + + // Tensor spatial sizes: input, output + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t OW; + uint32_t OH; + uint32_t OD; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // fastdiv helper values + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t OWOHODmp; uint32_t OWOHODL; +} + +p; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; +layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint SHMEM_PAD = 4; +// Stride, padding, dilation +layout(constant_id = 6) const uint s0 = 1; +layout(constant_id = 7) const uint s1 = 1; +layout(constant_id = 8) const uint s2 = 1; +layout(constant_id = 9) const uint p0 = 0; +layout(constant_id = 10) const uint p1 = 0; +layout(constant_id = 11) const uint p2 = 0; +layout(constant_id = 12) const uint d0 = 1; +layout(constant_id = 13) const uint d1 = 1; +layout(constant_id = 14) const uint d2 = 1; +// Kernel spatial sizes +layout(constant_id = 15) const uint KW = 1; +layout(constant_id = 16) const uint KH = 1; +layout(constant_id = 17) const uint KD = 1; +// when set, skip bounds checks and address clamps (K/CRS/NPQ are tile-aligned) +layout(constant_id = 18) const uint aligned = 0; +// stage cm2 result through shmem (Csh) for coalesced stores. cm1 always does this. +layout(constant_id = 19) const uint csh_store = 0; + +#ifdef COOPMAT +// cm1 subgroup tile: each subgroup computes a WM x WN region as a grid of +// TM x TN x TK fragments. Requires WM%TM == WN%TN == BS_K%WM == BS_NPQ%WN == +// BS_CRS%TK == 0, and WG_SIZE == (BS_K/WM) * (BS_NPQ/WN) * subgroup_size. +layout(constant_id = 20) const uint WM = 32; +layout(constant_id = 21) const uint WN = 32; +const uint TM = 16; +const uint TN = 16; +const uint TK = 16; +const uint cms_per_row = WM / TM; +const uint cms_per_col = WN / TN; +const uint warps_M = BS_K / WM; +const uint warps_N = BS_NPQ / WN; +#endif + +// without padding, ID_idx/IH_idx/IW_idx are in bounds by construction +const bool dhw_in_bounds = (p0 == 0) && (p1 == 0) && (p2 == 0); + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size) { + return (block_size + work_size - 1) / block_size; +} + +uint32_t K = p.OC; +uint32_t CRS = p.IC * KD * KH * KW; +uint32_t NPQ = p.N * p.OD * p.OH * p.OW; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +#if defined(COOPMAT2) || defined(COOPMAT) +#define SHMEM_TYPE float16_t +#else +#define SHMEM_TYPE float +#endif + +const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; + +const uint32_t Ash_len = BS_K * Ash_stride; +const uint32_t Bsh_len = BS_CRS * Bsh_stride; + +shared SHMEM_TYPE Ash[Ash_len]; // K x CRS +shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ + +#if defined(COOPMAT2) || defined(COOPMAT) +// stage matC through shmem so global stores are row-major (NPQ-contiguous) +const uint32_t Csh_stride = BS_NPQ; +#ifdef COOPMAT +const uint32_t Csh_len = BS_K * Csh_stride; +#else +const uint32_t Csh_len = csh_store != 0 ? BS_K * Csh_stride : 1; +#endif +shared SHMEM_TYPE Csh[Csh_len]; // K x NPQ +#endif + +// Threadtile sizes +const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; + +// Number of threadtiles per blocktile +const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; + +/* +Compute +KxCRS @ CRSxNPQ = K x NPQ +K=OC +C=IC +D,R,S=KD,KH,KW +Z,P,Q=OD,OH,OW +*/ + +uint32_t B_idx_K = gl_WorkGroupID.x; +uint32_t B_idx_NPQ = gl_WorkGroupID.y + gl_WorkGroupID.z * 512; + +uint32_t T_y = tid / NT_NPQ; +uint32_t T_x = tid % NT_NPQ; + +uint32_t Ar = tid / BS_CRS; +uint32_t Ac = tid % BS_CRS; +const uint32_t ArpWg = WG_SIZE / BS_CRS; + +uint32_t Br = tid / BS_NPQ; +uint32_t Bc = tid % BS_NPQ; +const uint32_t BrpWg = WG_SIZE / BS_NPQ; + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +void split_crs(uint32_t crs_idx, out uint32_t ic, out uint32_t kd, out uint32_t kh, out uint32_t kw) { + const uint32_t KHKW = KH * KW; + const uint32_t KDKHKW = KD * KHKW; + ic = crs_idx / KDKHKW; + uint32_t rem = crs_idx - ic * KDKHKW; + kd = rem / KHKW; + rem = rem - kd * KHKW; + kh = rem / KW; + kw = rem - kh * KW; +} + +void split_npq(uint32_t npq_idx, out uint32_t n, out uint32_t od, out uint32_t oh, out uint32_t ow) { + const uint32_t OWOH = p.OW * p.OH; + n = fastdiv(npq_idx, p.OWOHODmp, p.OWOHODL); + uint32_t rem = npq_idx - n * p.OD * OWOH; + od = fastdiv(rem, p.OWOHmp, p.OWOHL); + rem = rem - od * OWOH; + oh = fastdiv(rem, p.OWmp, p.OWL); + ow = rem - oh * p.OW; +} + +#ifdef COOPMAT2 +#define ACC_TYPE float16_t + +ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem) +{ + uint32_t K_idx = B_idx_K * BS_K + r; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c; + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(elem); + } + return elem; +} +#endif + +void main() { + if (B_idx_NPQ * BS_NPQ >= NPQ) { + return; + } + +#ifdef COOPMAT2 + coopmat matC; + matC = coopmat(0.0); +#elif defined(COOPMAT) + coopmat sums[cms_per_row * cms_per_col]; + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0); + } + const uint warp_r = gl_SubgroupID / warps_N; + const uint warp_c = gl_SubgroupID % warps_N; +#else + float regC[TS_K][TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = 0.0; + } + } +#endif + /* Advance block in CRS dim */ + [[dont_unroll]] for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { + uint32_t CRS_idx_a = B_idx_CRS * BS_CRS + Ac; + uint32_t IC_idx_a; + uint32_t KD_idx_a; + uint32_t KH_idx_a; + uint32_t KW_idx_a; + split_crs(CRS_idx_a, IC_idx_a, KD_idx_a, KH_idx_a, KW_idx_a); + + /* Load kernel to A_block: (BS_K x BS_CRS)*/ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + uint32_t B_ly = r_offset + Ar; + uint32_t B_lx = Ac; + uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ + uint32_t knl_idx = KW_idx_a + KH_idx_a * p.nb01 + KD_idx_a * p.nb02 + (K_idx * p.IC + IC_idx_a) * p.nb03; + if (aligned == 0) { + knl_idx = min(knl_idx, K * CRS - 1); + } + float val = knl_data[knl_idx]; + if (aligned == 0 && (K_idx >= K || CRS_idx_a >= CRS)) { + val = 0.0; + } + Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); + } + /* Load input to B_block: (BS_CRS x BS_NPQ) */ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { + uint32_t B_ly = r_offset + Br; /* Row index of B block */ + uint32_t B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + + uint32_t CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; + uint32_t IC_idx_b; + uint32_t KD_idx_b; + uint32_t KH_idx_b; + uint32_t KW_idx_b; + split_crs(CRS_idx_b, IC_idx_b, KD_idx_b, KH_idx_b, KW_idx_b); + + uint32_t ID_idx = OD_idx * s2 + KD_idx_b * d2 - p2; + uint32_t IH_idx = OH_idx * s1 + KH_idx_b * d1 - p1; + uint32_t IW_idx = OW_idx * s0 + KW_idx_b * d0 - p0; + + uint32_t src_idx = IW_idx + IH_idx * p.nb11 + ID_idx * p.nb12 + (N_idx * p.IC + IC_idx_b) * p.nb13; + // skip clamp when address can't go OOB + if (aligned == 0 || !dhw_in_bounds) { + src_idx = min(src_idx, p.IC * p.N * p.IW * p.IH * p.ID - 1); + } + float val = src_data[src_idx]; + bool oob = false; + if (aligned == 0 && (CRS_idx_b >= CRS || NPQ_idx >= NPQ)) { + oob = true; + } + // also catches lower-bound underflow (idx wraps to 0x80000000+) + if (!dhw_in_bounds && (ID_idx >= p.ID || IH_idx >= p.IH || IW_idx >= p.IW)) { + oob = true; + } + if (oob) { + val = 0.0; + } + Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); + } + barrier(); +#ifdef COOPMAT2 + coopmat matA; + coopmat matB; + + coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + matC = coopMatMulAdd(matA, matB, matC); +#elif defined(COOPMAT) + // each subgroup multiplies its grid of fragments per TK-sized CRS chunk + [[unroll]] for (uint k_step = 0; k_step < BS_CRS / TK; k_step++) { + coopmat cache_a[cms_per_row]; + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint a_off = (warp_r * WM + cm_row * TM) * Ash_stride + k_step * TK; + coopMatLoad(cache_a[cm_row], Ash, a_off, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + } + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopmat cache_b; + const uint b_off = k_step * TK * Bsh_stride + warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, Bsh, b_off, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a[cm_row], cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } +#else + if (T_y * TS_K < K) { + UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { + float regA[TS_K]; + float regB[TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; + } + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } + } + } + } +#endif + barrier(); + } + /* Save C* */ +#if defined(COOPMAT2) || defined(COOPMAT) + // stage matC into Csh, then write to dst with coalesced NPQ-contiguous stores +#ifdef COOPMAT + const bool use_staged_store = true; +#else + const bool use_staged_store = (csh_store != 0); +#endif + if (use_staged_store) { +#ifdef COOPMAT + // cm1: each subgroup stores its fragment grid into its Csh slot + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint csh_off = (warp_r * WM + cm_row * TM) * Csh_stride + warp_c * WN + cm_col * TN; + coopMatStore(sums[cm_col * cms_per_row + cm_row], Csh, csh_off, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } +#else + coopMatStore(matC, Csh, 0, Csh_stride, gl_CooperativeMatrixLayoutRowMajor); +#endif + barrier(); + + // cooperative shmem->global: WG threads spread across BS_NPQ (the + // contiguous direction of dst), each iter covers store_rows_per_iter K-rows + const uint32_t store_rows_per_iter = WG_SIZE / BS_NPQ; + const uint32_t store_iters = BS_K / store_rows_per_iter; + const uint32_t k_thread_offset = tid / BS_NPQ; + const uint32_t npq_thread = tid % BS_NPQ; + [[unroll]] for (uint32_t i = 0; i < store_iters; i++) { + uint32_t k_local = i * store_rows_per_iter + k_thread_offset; + uint32_t K_idx = B_idx_K * BS_K + k_local; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + npq_thread; + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(Csh[k_local * Csh_stride + npq_thread]); + } + } + } +#ifdef COOPMAT2 + else { + coopMatPerElementNV(matC, matC, perElemOpStore); + } +#endif +#else + if (T_y * TS_K < K) { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx; + uint32_t OD_idx; + uint32_t OH_idx; + uint32_t OW_idx; + split_npq(NPQ_idx, N_idx, OD_idx, OH_idx, OW_idx); + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + OD_idx * p.nb2 + (N_idx * p.OC + K_idx) * p.nb3; + if (aligned != 0 || (K_idx < K && NPQ_idx < NPQ)) { + dst_data[dst_idx] = D_TYPE(regC[T_ly][T_lx]); + } + } + } + } +#endif +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp deleted file mode 100644 index db6865db98..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 91fb07c93e..3192130ccf 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -463,6 +463,7 @@ void main() { } rowmaxf = max(rowmaxf, float(Sf[r][c])); } + rowmaxf += FATTN_KQ_MAX_OFFSET; float Moldf = Mf[r]; // M = max(rowmax, Mold) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 23ae3833e5..16178e5770 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -352,6 +352,7 @@ void main() { } rowmaxf = max(rowmaxf, float(sfsh[r_vec + (c * cols_per_iter + col_tid) * sfshstride][r_comp])); } + rowmaxf += FATTN_KQ_MAX_OFFSET; float Moldf = Mf[r]; // Compute max across the row diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp new file mode 100644 index 0000000000..7e3d8a2819 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_back.comp @@ -0,0 +1,25 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint col = gl_GlobalInvocationID.x; + + if (col >= p.ne20) { + return; + } + + for (uint row = gl_GlobalInvocationID.y; row < p.ne21; row += gl_WorkGroupSize.y * gl_NumWorkGroups.y) { + float sum = 0.0f; + for (uint i = 0; i < p.ne10; ++i) { + if (data_b[get_boffset() + i*p.nb10] == int(row)) { + sum += data_a[get_aoffset() + i*p.nb01 + col*p.nb00]; + } + } + + data_d[get_doffset() + row*p.nb21 + col*p.nb20] = sum; + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index f9af46744d..9039ed1ded 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -14,16 +14,13 @@ void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; - const uint i3 = row / (p.ne11 * p.ne12); - const uint i3_offset = i3 * p.ne12 * p.ne11; - const uint i2 = (row - i3_offset) / p.ne11; - const uint i2_offset = i2 * p.ne11; - const uint i1 = row - i3_offset - i2_offset; + const uint a_base = get_aoffset() + src0_idx(row * p.ne00); + const uint d_base = get_doffset() + dst_idx(row * p.ne10); sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]); + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[a_base + i0*p.nb00]); sum[tid] += xi * xi; } @@ -39,6 +36,6 @@ void main() { const FLOAT_TYPE scale = 1.0f / max(sqrt(sum[0]), FLOAT_TYPE(p.param1)); [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { - data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0])); + data_d[d_base + i0*p.nb10] = D_TYPE(scale * FLOAT_TYPE(data_a[a_base + i0*p.nb00])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp deleted file mode 100644 index b281e855cb..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp +++ /dev/null @@ -1,22 +0,0 @@ -#version 450 - -#include "generic_head.glsl" -#include "types.glsl" - -#extension GL_EXT_control_flow_attributes : enable - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - -void main() { - const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; - - if (i >= p.KX) { - return; - } - - const float val = float(data_a[i]); - data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index f39410d74f..57c0410e45 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -38,17 +38,7 @@ #define LOAD_VEC_B 1 #endif -// Load 2 values at once without affecting index calculations through LOAD_VEC -#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED) -#define LOAD_VEC_BATCH_A 2 -#else -#define LOAD_VEC_BATCH_A 1 -#endif -#if !defined(ALIGNED) -#define LOAD_VEC_BATCH_B 2 -#else -#define LOAD_VEC_BATCH_B 1 -#endif +layout (constant_id = 11) const uint ALIGNED = 0; #if !defined(TO_FLOAT_TYPE) #define TO_FLOAT_TYPE FLOAT_TYPE @@ -57,6 +47,13 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(DATA_A_F32) +layout (binding = 0) readonly buffer A_SCALAR {float data_a_scalar[];}; +#elif defined(DATA_A_F16) +layout (binding = 0) readonly buffer A_SCALAR {float16_t data_a_scalar[];}; +#elif defined(DATA_A_BF16) +layout (binding = 0) readonly buffer A_SCALAR {uint16_t data_a_scalar[];}; +#endif #if defined(A_TYPE_PACKED16) layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; #endif @@ -65,6 +62,7 @@ layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32 #endif layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 1) readonly buffer B_SCALAR {B_TYPE_SCALAR data_b_scalar[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; #ifdef MUL_MAT_ID @@ -194,13 +192,23 @@ void main() { const uint warp_r = warp_i % (BM / WM); const uint warp_c = warp_i / (BM / WM); - const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); - const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); - const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); - const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) + const uint LOAD_VEC_A_EFF = (ALIGNED != 0) ? LOAD_VEC_A : 1; + const uint LOAD_VEC_BATCH_A = (ALIGNED != 0) ? 1 : 2; +#else + const uint LOAD_VEC_A_EFF = LOAD_VEC_A; + const uint LOAD_VEC_BATCH_A = 1; +#endif + const uint LOAD_VEC_B_EFF = (ALIGNED != 0) ? LOAD_VEC_B : 1; + const uint LOAD_VEC_BATCH_B = (ALIGNED != 0) ? 1 : 2; - const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK; - const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK; + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A_EFF / LOAD_VEC_BATCH_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B_EFF / LOAD_VEC_BATCH_B); + + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A_EFF * LOAD_VEC_BATCH_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B_EFF * LOAD_VEC_BATCH_B / BK; #ifdef MUL_MAT_ID #ifdef MUL_MAT_ID_USE_SUBGROUPS @@ -239,15 +247,15 @@ void main() { uint pos_a = #ifdef MUL_MAT_ID - expert_idx * (p.batch_stride_a / LOAD_VEC_A) + + expert_idx * (p.batch_stride_a / LOAD_VEC_A_EFF) + #else - batch_idx_a * (p.batch_stride_a / LOAD_VEC_A) + + batch_idx_a * (p.batch_stride_a / LOAD_VEC_A_EFF) + #endif - (ir * BM * p.stride_a + start_k) / LOAD_VEC_A; + (ir * BM * p.stride_a + start_k) / LOAD_VEC_A_EFF; #ifdef MUL_MAT_ID uint pos_b = 0; #else - uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; + uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B_EFF; #endif #ifdef COOPMAT @@ -287,8 +295,8 @@ void main() { barrier(); - pos_a += BK / LOAD_VEC_A; - pos_b += BK / LOAD_VEC_B; + pos_a += BK / LOAD_VEC_A_EFF; + pos_b += BK / LOAD_VEC_B_EFF; #ifdef COOPMAT [[unroll]] for (uint i = 0; i < BK; i += TK) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 2656fe1c3e..a2e15f6f5c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -36,6 +36,7 @@ layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working wit layout (constant_id = 4) const bool enable_smaller_matrices = false; const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN; +layout (constant_id = 5) const uint ALIGNED = 0; layout (push_constant) uniform parameter { @@ -111,7 +112,7 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { }; uint _ne1; -layout (constant_id = 5) const uint subgroup_size = 32; +layout (constant_id = 6) const uint subgroup_size = 32; shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) @@ -297,12 +298,12 @@ void main() { // Hint to the compiler that values are aligned (want 16B alignment). // Quants are always block-aligned, no alignment needed. -#if ALIGNED + if (ALIGNED != 0) { #if QUANT_K == 1 - stride_a &= ~7; -#endif - stride_b &= ~7; + stride_a &= ~7; #endif + stride_b &= ~7; + } // Create layouts for both clamped and unclamped accesses tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 7359516898..56a8a0f187 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -1,50 +1,57 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) { #if defined(DATA_A_F32) || defined(DATA_A_F16) #if LOAD_VEC_A == 8 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); - buf_a[buf_idx ] = aa[0].xy; - buf_a[buf_idx + 1] = aa[0].zw; - buf_a[buf_idx + 2] = aa[1].xy; - buf_a[buf_idx + 3] = aa[1].zw; + if (ALIGNED != 0) { + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPEV8 aa = FLOAT_TYPEV8(data_a[idx]); + buf_a[buf_idx ] = aa[0].xy; + buf_a[buf_idx + 1] = aa[0].zw; + buf_a[buf_idx + 2] = aa[1].xy; + buf_a[buf_idx + 3] = aa[1].zw; + return; + } #elif LOAD_VEC_A == 4 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); - buf_a[buf_idx ] = aa.xy; - buf_a[buf_idx + 1] = aa.zw; -#else // LOAD_VEC_BATCH_A == 2 + if (ALIGNED != 0) { + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(data_a[idx]); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; + return; + } +#endif const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], - data_a[idx + 1]); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], + data_a_scalar[idx + 1]); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(data_a[idx], 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(data_a_scalar[idx], 0.0f); } else { buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif #elif defined(DATA_A_BF16) #if LOAD_VEC_A == 4 - const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); - buf_a[buf_idx ] = aa.xy; - buf_a[buf_idx + 1] = aa.zw; -#else // LOAD_VEC_BATCH_A == 2 + if (ALIGNED != 0) { + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPEV4 aa = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_a[idx])); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; + return; + } +#endif const uint idx = pos_a + col * p.stride_a + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_m < p.M && block + row * 2 + 1 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), - TO_FLOAT_TYPE(data_a[idx + 1])); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), + TO_FLOAT_TYPE(data_a_scalar[idx + 1])); } else if (idx_m < p.M && block + row * 2 < end_k) { - buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + buf_a[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_a_scalar[idx]), 0.0f); } else { buf_a[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; @@ -526,75 +533,85 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #if !defined(MUL_MAT_ID) void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) { #if LOAD_VEC_B == 8 - // Not supported for b_type bf16 because bf16mat2x4 does not exist - const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); - buf_b[buf_idx + 0] = bb[0].xy; - buf_b[buf_idx + 1] = bb[0].zw; - buf_b[buf_idx + 2] = bb[1].xy; - buf_b[buf_idx + 3] = bb[1].zw; + if (ALIGNED != 0) { + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; + return; + } #elif LOAD_VEC_B == 4 - const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + if (ALIGNED != 0) { + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; + return; + } #endif - buf_b[buf_idx + 0] = bb.xy; - buf_b[buf_idx + 1] = bb.zw; -#else // LOAD_VEC_BATCH_B == 2 const uint idx = pos_b + col * p.stride_b + row * 2; const uint buf_idx = col * SHMEM_STRIDE + row; if (idx_n < p.N && block + row * 2 + 1 < end_k) { - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), + TO_FLOAT_TYPE(data_b_scalar[idx + 1])); } else if (idx_n < p.N && block + row * 2 < end_k) { - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f); } else { buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif } #else void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) { #if LOAD_VEC_B == 8 - // Not supported for b_type bf16 because bf16mat2x4 does not exist - const u16vec2 row_idx = row_ids[col]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; - FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); - buf_b[buf_idx + 0] = bb[0].xy; - buf_b[buf_idx + 1] = bb[0].zw; - buf_b[buf_idx + 2] = bb[1].xy; - buf_b[buf_idx + 3] = bb[1].zw; + if (ALIGNED != 0) { + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPEV8 bb = FLOAT_TYPEV8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; + return; + } #elif LOAD_VEC_B == 4 - const u16vec2 row_idx = row_ids[col]; - const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; - const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + if (ALIGNED != 0) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; #if defined(DATA_B_BF16) - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(TO_FLOAT_TYPE(data_b[idx])); #else - FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); + FLOAT_TYPEV4 bb = FLOAT_TYPEV4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; + return; + } #endif - buf_b[buf_idx + 0] = bb.xy; - buf_b[buf_idx + 1] = bb.zw; -#else // LOAD_VEC_BATCH_B == 2 const uint row_i = ic * BN + col; const uint buf_idx = col * SHMEM_STRIDE + row; if (row_i < _ne1 && block + row * 2 + 1 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), - TO_FLOAT_TYPE(data_b[idx + 1])); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), + TO_FLOAT_TYPE(data_b_scalar[idx + 1])); } else if (row_i < _ne1 && block + row * 2 < end_k) { const u16vec2 row_idx = row_ids[col]; const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; - buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + buf_b[buf_idx] = FLOAT_TYPEV2(TO_FLOAT_TYPE(data_b_scalar[idx]), 0.0f); } else { buf_b[buf_idx] = FLOAT_TYPEV2(0.0f); } -#endif } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp index cc3ea0b760..792012d57e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -1,26 +1,26 @@ #version 450 -#include "generic_head.glsl" #include "types.glsl" +#include "generic_unary_head.glsl" #extension GL_EXT_control_flow_attributes : enable #define BLOCK_SIZE 512 layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared vec2 sum[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint a_base = get_aoffset() + src0_idx(row * p.ne00); + const uint d_base = get_doffset() + dst_idx(row * p.ne10); + sum[tid] = vec2(0.0f, 0.0f); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const float xi = float(data_a[row*p.KX + col]); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + const float xi = float(data_a[a_base + i0*p.nb00]); sum[tid].x += xi; sum[tid].y += xi * xi; } @@ -34,11 +34,11 @@ void main() { barrier(); } - const float mean = sum[0].x / p.KX; - const float var = sum[0].y / p.KX - mean * mean; + const float mean = sum[0].x / p.ne00; + const float var = sum[0].y / p.ne00 - mean * mean; const float inv_std = inversesqrt(var + p.param1); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + data_d[d_base + i0*p.nb10] = D_TYPE((float(data_a[a_base + i0*p.nb00]) - mean) * inv_std); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp deleted file mode 100644 index 61f17b2f00..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp deleted file mode 100644 index 70daad6c5d..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val)); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ggml/src/ggml-vulkan/vulkan-shaders/square.comp deleted file mode 100644 index 4eb56afcb1..0000000000 --- a/ggml/src/ggml-vulkan/vulkan-shaders/square.comp +++ /dev/null @@ -1,17 +0,0 @@ -#version 450 - -#include "types.glsl" -#include "generic_unary_head.glsl" - -layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; - -void main() { - const uint idx = get_idx(); - - if (idx >= p.ne) { - return; - } - - const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); - data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); -} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp b/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp index 47a4573996..c62bce8255 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/unary.comp @@ -17,6 +17,30 @@ float op_neg(float x) { return -x; } +float op_sqr(float x) { + return x * x; +} + +float op_sqrt(float x) { + return sqrt(x); +} + +float op_sin(float x) { + return sin(x); +} + +float op_cos(float x) { + return cos(x); +} + +float op_clamp(float x) { + return clamp(x, p.param1, p.param2); +} + +float op_leaky_relu(float x) { + return max(x, 0.0f) + min(x, 0.0f) * p.param1; +} + float op_step(float x) { return x >= 0.0f ? 1.0f : 0.0f; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index ca6b444314..1925582ffe 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -11,6 +11,7 @@ #include #include #include +#include #include #include #include @@ -34,6 +35,9 @@ std::mutex lock; std::vector> shader_fnames; +// Set when any shader subprocess fails (non-zero exit / stderr / launch failure) so the +// build is stopped instead of silently producing a broken libggml-vulkan. (issue #24393) +static std::atomic compile_failed{false}; std::locale c_locale("C"); std::string GLSLC = "glslc"; @@ -78,7 +82,7 @@ enum MatMulIdType { namespace { -void execute_command(std::vector& command, std::string& stdout_str, std::string& stderr_str) { +int execute_command(std::vector& command, std::string& stdout_str, std::string& stderr_str) { #ifdef _WIN32 HANDLE stdout_read, stdout_write; HANDLE stderr_read, stderr_write; @@ -127,8 +131,11 @@ void execute_command(std::vector& command, std::string& stdout_str, CloseHandle(stdout_read); CloseHandle(stderr_read); WaitForSingleObject(pi.hProcess, INFINITE); + DWORD exit_code = 1; + GetExitCodeProcess(pi.hProcess, &exit_code); CloseHandle(pi.hProcess); CloseHandle(pi.hThread); + return (int)exit_code; #else int stdout_pipe[2]; int stderr_pipe[2]; @@ -175,7 +182,9 @@ void execute_command(std::vector& command, std::string& stdout_str, close(stdout_pipe[0]); close(stderr_pipe[0]); - waitpid(pid, nullptr, 0); + int status = 0; + waitpid(pid, &status, 0); + return WIFEXITED(status) ? WEXITSTATUS(status) : -1; } #endif } @@ -372,13 +381,14 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p // } // std::cout << std::endl; - execute_command(cmd, stdout_str, stderr_str); - if (!stderr_str.empty()) { - std::cerr << "cannot compile " << name << "\n\n"; + int exit_code = execute_command(cmd, stdout_str, stderr_str); + if (exit_code != 0 || !stderr_str.empty()) { + std::cerr << "cannot compile " << name << " (exit code " << exit_code << ")\n\n"; for (const auto& part : cmd) { std::cerr << part << " "; } std::cerr << "\n\n" << stderr_str << std::endl; + compile_failed = true; return; } @@ -398,6 +408,7 @@ void string_to_spv_func(std::string name, std::string in_path, std::string out_p shader_fnames.push_back(std::make_pair(name, out_path)); } catch (const std::exception& e) { std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; + compile_failed = true; } } @@ -539,11 +550,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c }; // Shaders with f16 B_TYPE - string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f32_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); // bf16 { @@ -565,8 +574,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c #endif { if (!dot2) { - string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE_SCALAR", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"B_TYPEV4", "bf16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); } } } @@ -583,8 +591,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } std::string data_a_key = "DATA_A_" + to_uppercase(tname); - // For unaligned, load one at a time for f32/f16, or two at a time for quants - std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; // For aligned matmul loads std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; @@ -597,13 +603,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c // don't generate f32 variants for coopmat2 if (!coopmat2) { - string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE_SCALAR", "float"}, {"B_TYPEV4", "vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); } if (tname != "f16" && tname != "f32") { - string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); - string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx + "_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16" + dot2_sfx, source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE_SCALAR", "float16_t"}, {"B_TYPEV4", "f16vec4"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); } #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) @@ -850,21 +854,12 @@ void process_shaders() { string_to_spv("repeat_i32", "repeat.comp", {{"A_TYPE", "int32_t"}, {"D_TYPE", "int32_t"}}); string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("get_rows_back_f32", "get_rows_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}); string_to_spv("repeat_i16", "repeat.comp", {{"A_TYPE", "int16_t"}, {"D_TYPE", "int16_t"}}); string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - - string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); - string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("concat_i8", "concat.comp", {{"A_TYPE", "uint8_t"}, {"B_TYPE", "uint8_t"}, {"D_TYPE", "uint8_t"}}); @@ -891,6 +886,18 @@ void process_shaders() { string_to_spv("silu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_silu"}}); string_to_spv("relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_relu"}}); string_to_spv("relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_relu"}}); + string_to_spv("sqr_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqr"}}); + string_to_spv("sqr_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqr"}}); + string_to_spv("sqrt_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sqrt"}}); + string_to_spv("sqrt_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sqrt"}}); + string_to_spv("sin_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_sin"}}); + string_to_spv("sin_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_sin"}}); + string_to_spv("cos_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_cos"}}); + string_to_spv("cos_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_cos"}}); + string_to_spv("clamp_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_clamp"}}); + string_to_spv("clamp_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_clamp"}}); + string_to_spv("leaky_relu_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_leaky_relu"}}); + string_to_spv("leaky_relu_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_leaky_relu"}}); string_to_spv("neg_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_neg"}}); string_to_spv("neg_f32", "unary.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"OP", "op_neg"}}); string_to_spv("tanh_f16", "unary.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OP", "op_tanh"}}); @@ -948,7 +955,6 @@ void process_shaders() { string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); - string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); @@ -1060,6 +1066,31 @@ void process_shaders() { } } + for (auto unroll : {false, true}) { + for (auto a_f16 : {false, true}) { + std::map defines = { + {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, + {"UNROLL", unroll ? "[[unroll]]" : ""}, + }; + std::string name = std::string("conv3d") + (a_f16 ? "_f16" : "") + "_f32"; + string_to_spv(name + (unroll ? "_unroll" : ""), "conv3d_mm.comp", defines); +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (unroll) { + auto cm2_defines = defines; + cm2_defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv3d_mm.comp", cm2_defines, true, false, true); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (unroll) { + auto cm1_defines = defines; + cm1_defines["COOPMAT"] = "1"; + string_to_spv(name, "conv3d_mm.comp", cm1_defines, true, true, false); + } +#endif + } + } + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); @@ -1251,6 +1282,11 @@ int main(int argc, char** argv) { process_shaders(); + if (compile_failed) { + std::cerr << "vulkan-shaders-gen: one or more shaders failed to compile" << std::endl; + return EXIT_FAILURE; + } + write_output_files(); return EXIT_SUCCESS; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6f877f15ce..c00a2e9ee9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + uint32_t num_cols; bool use_mmvq; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && - use_mmvq == other.use_mmvq; + num_cols == other.num_cols && use_mmvq == other.use_mmvq; } }; @@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.num_cols); ggml_webgpu_hash_combine(seed, key.use_mmvq); return seed; } @@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key { ggml_type src0_type; ggml_type src1_type; uint32_t n_experts; + uint32_t num_cols; int vectorized; bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts && - vectorized == other.vectorized; + num_cols == other.num_cols && vectorized == other.vectorized; } }; @@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.n_experts); + ggml_webgpu_hash_combine(seed, key.num_cols); ggml_webgpu_hash_combine(seed, key.vectorized); return seed; } @@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, const ggml_tensor * src1, bool supports_dot_product, const std::string & vendor) { - if (src1->ne[1] == 1) { + if (src1->ne[1] <= 4) { bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; if (supports_dp4a && supports_dot_product) { switch (src1->type) { @@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; + key.num_cols = context.dst->ne[1]; key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); @@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols)); auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); @@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=1")); defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index 0b605fa86b..f0ec18abd9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & const size_t dst_offset = ggml_webgpu_tensor_offset(dst); const size_t q8_src1_align_offset = ROUNDUP_POW2( dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const size_t q8_src1_binding_size = - ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), - WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t q8_src1_binding_size = ROUNDUP_POW2( + src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); std::vector q8_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], }; @@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & uint32_t q8_wg_x = 1; uint32_t q8_wg_y = 1; const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; - const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); @@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * dst) { // Determine if this is a mat-vec operation - bool is_vec = (dst->ne[1] == 1); + bool use_mat_vec = (dst->ne[1] <= 4); // use MMVQ path for mat-vec bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, @@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, webgpu_pipeline pipeline; std::vector dispatches; - if (is_vec) { + if (use_mat_vec) { if (use_mmvq) { ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); } @@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - if (is_vec) { + if (use_mat_vec) { auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; @@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, ctx->webgpu_global_ctx->vendor); if (use_mmvq) { - const size_t q8_src1_size = - src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] * + (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); res = ROUNDUP_POW2(res + q8_src1_size + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, WEBGPU_STORAGE_BUF_BINDING_MULT); @@ -3788,7 +3790,7 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipeline = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { +static void ggml_backend_webgpu_request_adapter(wgpu::Instance & instance, wgpu::Adapter & adapter) { wgpu::RequestAdapterOptions options = {}; #ifndef __EMSCRIPTEN__ @@ -3800,17 +3802,20 @@ static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { options.nextInChain = &adapterTogglesDesc; #endif - ctx->webgpu_global_ctx->instance.WaitAny( - ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - ctx->webgpu_global_ctx->adapter = std::move(adapter); - }), - UINT64_MAX); + instance.WaitAny(instance.RequestAdapter( + &options, wgpu::CallbackMode::AllowSpontaneous, + [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { + if (status != wgpu::RequestAdapterStatus::Success) { + GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); + return; + } + adapter = std::move(_adapter); + }), + UINT64_MAX); +} + +static void create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { + ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, ctx->webgpu_global_ctx->adapter); GGML_ASSERT(ctx->webgpu_global_ctx->adapter != nullptr); ctx->webgpu_global_ctx->adapter.GetLimits(&ctx->webgpu_global_ctx->capabilities.limits); @@ -4265,7 +4270,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_OP_RMS_NORM: case GGML_OP_NORM: case GGML_OP_L2_NORM: - supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + supports_op = (op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32) && ggml_is_contiguous_rows(src0); break; case GGML_OP_ROPE: supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; @@ -4543,20 +4548,7 @@ ggml_backend_reg_t ggml_backend_webgpu_reg() { // Probe for adapter support wgpu::Adapter adapter; if (ctx->webgpu_global_ctx->instance != nullptr) { - wgpu::RequestAdapterOptions options = {}; - - // probe for adapter support - ctx->webgpu_global_ctx->instance.WaitAny( - ctx->webgpu_global_ctx->instance.RequestAdapter( - &options, wgpu::CallbackMode::AllowSpontaneous, - [&adapter](wgpu::RequestAdapterStatus status, wgpu::Adapter _adapter, const char * message) { - if (status != wgpu::RequestAdapterStatus::Success) { - GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message); - return; - } - adapter = std::move(_adapter); - }), - UINT64_MAX); + ggml_backend_webgpu_request_adapter(ctx->webgpu_global_ctx->instance, adapter); } // WebGPU backend requires f16 support and, on native, implicit device synchronization. diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl index 6ff9bcf2df..78ae955e6b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl @@ -103,7 +103,7 @@ fn main( #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); + let subgroup_total = subgroupAdd(acc[0][row]); if (subgroup_invocation_id == 0u) { partial_sums[partial_index(row, subgroup_id)] = subgroup_total; } @@ -126,7 +126,7 @@ fn main( #ifdef USE_WORKGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; + partial_sums[partial_index(row, thread_id)] = acc[0][row]; } workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index f0a7fbd059..ebdf09513e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -91,61 +91,67 @@ fn main( let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; #ifdef MMVQ - let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u); let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); #else let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); #endif + for (var col = 0u;col < NUM_COLS;col += 1) { + #ifdef USE_SUBGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); - if (subgroup_invocation_id == 0u) { - partial_sums[partial_index(row, subgroup_id)] = subgroup_total; - } - } + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[col][row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } - workgroupBarrier(); + workgroupBarrier(); - for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { - let output_row = row_base + row; - var row_acc = 0.0f; - for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { - row_acc += partial_sums[partial_index(row, k)]; - } - let row_total = subgroupAdd(row_acc); - if (subgroup_invocation_id == 0) { - dst[dst_idx_base + row] = row_total; - } - } + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + col * params.m + row] = row_total; + } + } #endif #ifdef USE_WORKGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; - } + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[col][row]; + } + + workgroupBarrier(); + + var stride = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif workgroupBarrier(); - var stride = WG_SIZE / 2u; - - while (stride > 0) { - if (thread_id < stride) { - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; - } - } - - workgroupBarrier(); - stride = stride / 2; } - - if (thread_id < OUTPUTS_PER_WG) { - let output_row = row_base + thread_id; - if (output_row < params.m) { - dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; - } - } -#endif } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 08753b9d64..b0703fe906 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -32,8 +32,8 @@ fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { #endif #ifdef MUL_ACC_FLOAT -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let k_vec = params.k / VEC_SIZE; let src1_idx_base_vec = src1_idx_base / VEC_SIZE; @@ -41,12 +41,18 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src // Each thread walks K, loads from the vector, and updates // a small block of output rows held in registers. for (var k = thread_id; k < k_vec; k += WG_SIZE) { - let x = src1[src1_idx_base_vec + k]; + var x_vals: array; + for (var col = 0u;col < NUM_COLS;col += 1) { + x_vals[col] = src1[src1_idx_base_vec + col * (params.stride_11 / VEC_SIZE) + k]; + } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; - acc[row] += inner_dot(src0[src0_idx], x); + let w = src0[src0_idx]; + for (var col = 0u;col < NUM_COLS;col += 1) { + acc[col][row] += inner_dot(w, x_vals[col]); + } } } } @@ -60,30 +66,33 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 16 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; - var row_sum = 0.0; - for (var bit = 0u; bit < 8u; bit++) { - let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); - row_sum += w * x_block[bit]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[col][bit]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -97,35 +106,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; - let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -139,36 +150,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 20 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(q_byte & 0xFu) * d + m; - let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -182,19 +195,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 22 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -203,18 +217,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 2u); let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; - let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -228,19 +243,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 24 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -250,18 +266,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 4u); let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; - let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -275,33 +292,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 34 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -315,34 +337,39 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 36 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d + m; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -355,8 +382,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 84 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -379,14 +406,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 64u + i]); - x_block[i + 12u] = f32(src1[x_base + 96u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 64u + i]); + x_block[col][i + 12u] = f32(src1[x_base + col * params.stride_11 + 96u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -404,30 +432,32 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qs0 = q_u32 & 0xFFFFu; let qs1 = q_u32 >> 16u; - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; - sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; - sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; - sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + sumy[0] = x_block[col][0] + x_block[col][1] + x_block[col][2] + x_block[col][3]; + sumy[1] = x_block[col][4] + x_block[col][5] + x_block[col][6] + x_block[col][7]; + sumy[2] = x_block[col][8] + x_block[col][9] + x_block[col][10] + x_block[col][11]; + sumy[3] = x_block[col][12] + x_block[col][13] + x_block[col][14] + x_block[col][15]; - acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); - acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); - acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); - acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); - acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); - acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); - acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); - acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + acc1[0] = x_block[col][0] * f32(qs0 & 0x0003u) + x_block[col][2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[col][1] * f32(qs0 & 0x0300u) + x_block[col][3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[col][4] * f32(qs0 & 0x000Cu) + x_block[col][6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[col][5] * f32(qs0 & 0x0C00u) + x_block[col][7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[col][8] * f32(qs0 & 0x0030u) + x_block[col][10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[col][9] * f32(qs0 & 0x3000u) + x_block[col][11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[col][12] * f32(qs0 & 0x00C0u) + x_block[col][14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[col][13] * f32(qs0 & 0xC000u) + x_block[col][15] * f32(qs1 & 0xC000u); - acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + - (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + - (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + - (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) - - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + - sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + acc[col][row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } } } } @@ -440,8 +470,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -485,12 +515,13 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 8u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 8u] = f32(src1[x_base + 32u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 8u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -516,28 +547,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); - var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; - var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + for (var col = 0u;col < NUM_COLS;col += 1) { + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; - for (var l = 0u; l < 8u; l += 2u) { - let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); - let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); - let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); - let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); - s1 += x_block[l + 0u] * f32(qs & qm0); - s2 += x_block[l + 1u] * f32(qs & qm1); - s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + - select(0.0, x_block[l + 1u], (hv & hm1) == 0u); - s4 += x_block[l + 8u] * f32(qs & qm2); - s5 += x_block[l + 9u] * f32(qs & qm3); - s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + - select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + s1 += x_block[col][l + 0u] * f32(qs & qm0); + s2 += x_block[col][l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[col][l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[col][l + 1u], (hv & hm1) == 0u); + s4 += x_block[col][l + 8u] * f32(qs & qm2); + s5 += x_block[col][l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[col][l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[col][l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[col][row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); } - - let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); - let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); - acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); } } } @@ -550,8 +583,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 144 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -573,12 +606,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { @@ -613,23 +649,25 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); - var dot = vec4(0.0, 0.0, 0.0, 0.0); - var sumx = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - dot[0] += x_block[i] * f32(q1b & 0x0Fu); - dot[1] += x_block[i + 4u] * f32(q1b >> 4u); - dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); - dot[3] += x_block[i + 12u] * f32(q2b >> 4u); - sumx[0] += x_block[i]; - sumx[1] += x_block[i + 4u]; - sumx[2] += x_block[i + 8u]; - sumx[3] += x_block[i + 12u]; - } + for (var col = 0u;col < NUM_COLS;col += 1) { + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[col][i] * f32(q1b & 0x0Fu); + dot[1] += x_block[col][i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[col][i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[col][i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[col][i]; + sumx[1] += x_block[col][i + 4u]; + sumx[2] += x_block[col][i + 8u]; + sumx[3] += x_block[col][i + 12u]; + } - acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) - - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + acc[col][row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } } } } @@ -642,8 +680,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 176 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -671,14 +709,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -712,37 +752,39 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); - var vals = vec4(0.0, 0.0, 0.0, 0.0); - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - let qhb = byte_of(qh_u32, i); + for (var col = 0u;col < NUM_COLS;col += 1) { + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); - let yl0 = x_block[i]; - let yl8 = x_block[i + 4u]; - let yh0 = x_block[i + 8u]; - let yh8 = x_block[i + 12u]; + let yl0 = x_block[col][i]; + let yl8 = x_block[col][i + 4u]; + let yh0 = x_block[col][i + 8u]; + let yh8 = x_block[col][i + 12u]; - sumy[0] += yl0; - sumy[1] += yl8; - sumy[2] += yh0; - sumy[3] += yh8; + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; - let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); - let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); - let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); - let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); - vals[0] += yl0 * q0; - vals[1] += yl8 * q1; - vals[2] += yh0 * q2; - vals[3] += yh8 * q3; + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[col][row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); } - - acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) - - dmin * (sumy[0] * m0 + sumy[1] * m1 + - sumy[2] * m4 + sumy[3] * m5); } } } @@ -755,8 +797,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 210 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -777,14 +819,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var l = 0u; l < 4u; l++) { - x_block[l] = f32(src1[x_base + l]); - x_block[l + 4u] = f32(src1[x_base + 32u + l]); - x_block[l + 8u] = f32(src1[x_base + 64u + l]); - x_block[l + 12u] = f32(src1[x_base + 96u + l]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var l = 0u; l < 4u; l++) { + x_block[col][l] = f32(src1[col_base + l]); + x_block[col][l + 4u] = f32(src1[col_base + 32u + l]); + x_block[col][l + 8u] = f32(src1[col_base + 64u + l]); + x_block[col][l + 12u] = f32(src1[col_base + 96u + l]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -802,26 +846,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); - var sums = vec4(0.0, 0.0, 0.0, 0.0); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sums = vec4(0.0, 0.0, 0.0, 0.0); - for (var l = 0u; l < 4u; l++) { - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - sums[0] += x_block[l] * dq0; - sums[1] += x_block[l + 4u] * dq1; - sums[2] += x_block[l + 8u] * dq2; - sums[3] += x_block[l + 12u] * dq3; + sums[0] += x_block[col][l] * dq0; + sums[1] += x_block[col][l + 4u] * dq1; + sums[2] += x_block[col][l + 8u] * dq2; + sums[3] += x_block[col][l + 12u] * dq3; + } + + acc[col][row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); } - - acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); } } } @@ -834,8 +880,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 50 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -850,11 +896,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -866,20 +913,22 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -892,8 +941,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 56 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -908,11 +957,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -936,26 +986,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_lo = qh & 0xFFu; let qh_hi = (qh >> 8u) & 0xFFu; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); - let sub_scale = (sc_u16 >> bit_off) & 0x7u; - let dl = d * f32(2u * sub_scale + 1u); - let qh_byte = select(qh_lo, qh_hi, l >= 2u); - let ll2 = l % 2u; - let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); - let ig = grid_idx * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -968,8 +1020,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 66 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -984,11 +1036,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -999,22 +1052,24 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux_hi >> 28u; let db = d * (0.5 + f32(ls)) * 0.25; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; - let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xxs_grid[grid_idx * 2u]; - let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1027,8 +1082,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 74 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1043,11 +1098,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1058,27 +1114,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(scales_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let half2 = (l % 2u) * 16u; - let qs_val = (qs_word >> half2) & 0xFFFFu; - let grid_idx = qs_val & 0x1FFu; - let signs_idx = (qs_val >> 9u) & 0x7Fu; - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xs_grid[grid_idx * 2u]; - let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1091,8 +1149,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 82 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1107,11 +1165,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1124,24 +1183,26 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(sc_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let sign_byte = get_byte(sg_w, l); - let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let gw_lo = iq2s_grid[grid_idx * 2u]; - let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1154,8 +1215,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 98 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1170,11 +1231,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1186,27 +1248,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux >> 28u; let db = d * (0.5 + f32(ls)) * 0.5; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let signs_idx = (aux >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let grid1 = iq3xxs_grid[grid_idx_0]; - let grid2 = iq3xxs_grid[grid_idx_1]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1219,8 +1283,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1235,11 +1299,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1255,28 +1320,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; let db = d * (1.0 + 2.0 * f32(sub_scale)); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); - let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); - let sign_byte = get_byte(sg_w, l); - let grid1 = iq3s_grid[grid_idx_1]; - let grid2 = iq3s_grid[grid_idx_2]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1290,35 +1357,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + i + 16u]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + i + 16u]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; - let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1331,8 +1400,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 136 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1346,11 +1415,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1370,17 +1440,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); - var row_sum = 0.0; - for (var i = 0u; i < 16u; i++) { - let q_word = select( - select(q_w0, q_w1, i >= 4u), - select(q_w2, q_w3, i >= 12u), - i >= 8u); - let q_byte = get_byte(q_word, i % 4u); - let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); - row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[col][i]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1394,35 +1466,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 17 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); let e = ldexp(1.0, i32(eu8) - 128); - var row_sum = 0.0; let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; - let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl index 3ef2f77ebe..6ccaf61a6a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_0 #ifdef MUL_ACC_Q4_1 #define BLOCK_SIZE_BYTES 20 @@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2 { f32(load_f16_at_src0(block_byte_base + 2u)) ); } -fn mul_q8_1(row_sum: i32, dma: vec2, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_1 #ifdef MUL_ACC_Q8_0 #define BLOCK_SIZE_BYTES 34 @@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds); -} -#endif +#endif // MUL_ACC_Q8_0 -#ifdef LEGACY_QUANTS -fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2, b_ds: B_DS_TYPE) -> f32 { - var row_sum = 0; - let a_repacked = repack_a(a_byte_base, b_inner_id); - - row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]); - row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); - - return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); -} - -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +#if defined(LEGACY_QUANTS) +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let b_inner_id = thread_id % THREADS_PER_BLOCK; - let b_block_idx = src1q_idx_base + block; - - let b_repacked = repack_b_qs(b_block_idx, b_inner_id); - let b_ds = repack_b_dm(b_block_idx); - + let inner_id = thread_id % THREADS_PER_BLOCK; for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, inner_id); + let da = get_dm(block_byte_base); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block; + let b_repacked = repack_b_qs(src1q_idx, inner_id); + let b_ds = repack_b_dm(src1q_idx); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]); + +#if defined(MUL_ACC_Q4_0) + acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_0 + +#if defined(MUL_ACC_Q4_1) + acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_1 + +#if defined(MUL_ACC_Q8_0) + acc[col][row] += f32(row_sum) * (da * b_ds); +#endif // MUL_ACC_Q8_0 + } } } } return acc; } -#endif +#endif // LEGACY_QUANTS #ifdef MUL_ACC_Q2_K #define BLOCK_SIZE_BYTES 84 @@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); return vec2(f32(scale & 0xFu), f32(scale >> 4u)); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let scale_q = i32(scale_min.x); - let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; - - let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) - + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; - let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) - + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); - - return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); -} -#endif +#endif // MUL_ACC_Q2_K #ifdef MUL_ACC_Q4_K #define BLOCK_SIZE_BYTES 144 @@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { return vec2(scale, min_val); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) - + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); - - // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. - return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); -} -#endif +#endif // MUL_ACC_Q4_K #ifdef K_QUANTS -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { - let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; - let b_repacked = repack_b_qs(src1q_idx, tid); - let b_ds = repack_b_dm(src1q_idx); - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, tid); + let dm = get_dm(block_byte_base); + let scale_min = get_scale_min(block_byte_base, tid); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + +#if defined(MUL_ACC_Q2_K) + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); + + acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +#endif // MUL_ACC_Q2_K + +#if defined(MUL_ACC_Q4_K) + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +#endif // MUL_ACC_Q4_K + + } } } } return acc; } -#endif +#endif // K_QUANTS diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl index b3f1fa04b8..847b27ffad 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product; struct Params { offset_src1: u32, + stride_11: u32, stride_12: u32, stride_13: u32, ne0: u32, + ne1: u32, ne2: u32, ne3: u32, }; @@ -57,25 +59,28 @@ fn main( @builtin(num_workgroups) num_wg: vec3 ) { let thread_id = local_id.x; - let num_vec4 = params.ne0 / 4u; + let ne0_vec4 = params.ne0 / 4u; - let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; - let total_batches = wg_per_vec * params.ne2 * params.ne3; + let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3; let wg_linear = wg_id.y * num_wg.x + wg_id.x; if (wg_linear >= total_batches) { return; } - let src13_idx = wg_linear / (params.ne2 * wg_per_vec); - let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; - let src11_wg_idx = wg_linear % wg_per_vec; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let vec_idx = wg_linear / wg_per_vec; + let src13_idx = vec_idx / (params.ne2 * params.ne1); + let vec_ne12_num = vec_idx % (params.ne2 * params.ne1); + let src12_idx = vec_ne12_num / params.ne1; + let src11_idx = vec_ne12_num % params.ne1; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11; let src1_idx_vec4_base = src1_idx_base / 4u; let blocks_per_row = params.ne0 / 32u; let blocks_per_wg = (WG_SIZE * 4u) / 32u; - let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row; + let src11_wg_idx = wg_linear % wg_per_vec; let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; let qs_idx = thread_id % 8u; @@ -85,7 +90,7 @@ fn main( var thread_amax = 0.0; let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; - let is_valid = src11_vec4_idx < num_vec4; + let is_valid = src11_vec4_idx < ne0_vec4; #ifdef USE_SUBGROUP_REDUCTION diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index b43016c87d..0f682fd185 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -600,18 +600,15 @@ FILE * ggml_fopen(const char * fname, const char * mode) { // convert fname (UTF-8) wchar_t * wfname = ggml_mbstowcs(fname); if (wfname) { - // convert mode (ANSI) - wchar_t * wmode = GGML_MALLOC((strlen(mode) + 1) * sizeof(wchar_t)); - wchar_t * wmode_p = wmode; - do { - *wmode_p++ = (wchar_t)*mode; - } while (*mode++); - - // open file - file = _wfopen(wfname, wmode); + // convert mode (UTF-8) + wchar_t * wmode = ggml_mbstowcs(mode); + if (wmode) { + // open file + file = _wfopen(wfname, wmode); + GGML_FREE(wmode); + } GGML_FREE(wfname); - GGML_FREE(wmode); } return file; diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 463963f2ac..1bda9452dd 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -359,6 +359,7 @@ class Keys: CHUNK_SIZE = "clip.audio.chunk_size" CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size" MAX_POS_EMB = "clip.audio.max_pos_emb" + FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus class Attention: HEAD_COUNT = "clip.audio.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index f707f29dc5..a06ec88b32 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1310,6 +1310,9 @@ class GGUFWriter: def add_audio_max_pos_emb(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value) + def add_audio_feature_layers(self, layers: Sequence[int]) -> None: + self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers) + def add_audio_projector_window_size(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value) diff --git a/include/llama.h b/include/llama.h index 27e4806742..f723c9f60c 100644 --- a/include/llama.h +++ b/include/llama.h @@ -558,14 +558,15 @@ extern "C" { LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); - LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); - LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_ctx_train (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_layer_nextn(const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_swa (const struct llama_model * model); // Get the model's RoPE frequency scaling factor LLAMA_API float llama_model_rope_freq_scale_train(const struct llama_model * model); diff --git a/scripts/snapdragon/adb/run-completion.sh b/scripts/snapdragon/adb/run-completion.sh index fe14bb1422..f7622eb527 100755 --- a/scripts/snapdragon/adb/run-completion.sh +++ b/scripts/snapdragon/adb/run-completion.sh @@ -57,19 +57,25 @@ oppoll= opflt= [ "$OF" != "" ] && opflt="GGML_HEXAGON_OPFILTER=$OF" +opfuse= +[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC" + vmem= [ "$VM" != "" ] && vmem="GGML_HEXAGON_VMEM=$VM" mbuf= [ "$MB" != "" ] && mbuf="GGML_HEXAGON_MBUF=$MB" +mmsel= +[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM" + set -x adb $adbserial $adbhost shell " \ cd $basedir; ulimit -c unlimited; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $vmem $mbuf \ + $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opflt $opfuse $vmem $mbuf $mmsel \ ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ --ctx-size 8192 --ubatch-size 1024 -fa on \ diff --git a/scripts/snapdragon/adb/run-tool.sh b/scripts/snapdragon/adb/run-tool.sh index 6d7e32b321..f6332391bc 100755 --- a/scripts/snapdragon/adb/run-tool.sh +++ b/scripts/snapdragon/adb/run-tool.sh @@ -51,6 +51,12 @@ opqueue= oppoll= [ "$OP" != "" ] && oppoll="GGML_HEXAGON_OPPOLL=$OP" +opfuse= +[ "$OC" != "" ] && opfuse="GGML_HEXAGON_OPFUSION=$OC" + +mmsel= +[ "$MM" != "" ] && mmsel="GGML_HEXAGON_MM_SELECT=$MM" + set -x tool=$1; shift @@ -59,5 +65,5 @@ adb $adbserial $adbhost shell " \ cd $basedir; ulimit -c unlimited; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll ./$branch/bin/$tool $@ \ + $verbose $sched $opmask $profile $nhvx $hmx $ndev $hb $opbatch $opqueue $oppoll $opfuse $mmsel ./$branch/bin/$tool $@ \ " diff --git a/scripts/snapdragon/ggml-hexagon-profile.py b/scripts/snapdragon/ggml-hexagon-profile.py index 05045262f2..c53ad77793 100755 --- a/scripts/snapdragon/ggml-hexagon-profile.py +++ b/scripts/snapdragon/ggml-hexagon-profile.py @@ -26,7 +26,7 @@ COL_MAP = { } op_pattern = re.compile( - r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+.*?\s+(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?" + r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+.*?\s+:\s+(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?" ) trace_pattern = re.compile( @@ -93,9 +93,40 @@ def parse_log(file_path, pmu_index=None): + int(ts_match.group('us')) ) - op_match = op_pattern.search(line) + if "|" in line and "profile-op" in line: + parts = [p.strip() for p in line.split("|")] + prefix = parts[0] + prefix_match = re.search(r"profile-op\s+(?P[A-Z_0-9+]+)", prefix) + if not prefix_match: + continue + + if len(parts) == 7: + dims, types, timings = parts[2], parts[3], parts[6] + elif len(parts) == 6: + dims, types, timings = parts[2], parts[3], parts[5] + else: + continue + + timing_match = re.search( + r"(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?", + timings + ) + if not timing_match: + continue + + op_match = timing_match + op_name = prefix_match.group("op_name") + else: + op_match = op_pattern.search(line) + if op_match: + op_name = op_match.group('op_name') + dims = op_match.group('dims').strip() + types = op_match.group('types').strip() + else: + op_match = None + if op_match: - pmu_raw = op_match.group('pmu') + pmu_raw = op_match.group('pmu') if 'pmu' in op_match.groupdict() else None pmu_val = None if pmu_raw and pmu_index is not None: try: @@ -105,7 +136,7 @@ def parse_log(file_path, pmu_index=None): except (ValueError, IndexError): pmu_val = None - evt_raw = op_match.group('evt') + evt_raw = op_match.group('evt') if 'evt' in op_match.groupdict() else None evt_val = None if evt_raw: try: @@ -122,9 +153,9 @@ def parse_log(file_path, pmu_index=None): op_text = line[idx + 11:].strip() if idx != -1 else line.strip() current_op = { - 'name': op_match.group('op_name'), - 'dims': op_match.group('dims').strip(), - 'types': op_match.group('types').strip(), + 'name': op_name, + 'dims': dims, + 'types': types, 'op_text': op_text, 'usec': int(op_match.group('usec')), 'cycles': int(op_match.group('cycles')), diff --git a/scripts/snapdragon/ggml-hexagon-trace.py b/scripts/snapdragon/ggml-hexagon-trace.py index 18ec440a9f..37f137a9e7 100755 --- a/scripts/snapdragon/ggml-hexagon-trace.py +++ b/scripts/snapdragon/ggml-hexagon-trace.py @@ -12,7 +12,7 @@ from collections import defaultdict logger = logging.getLogger("ggml-hexagon-trace") op_pattern = re.compile( - r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?" + r"profile-op\s+(?P[A-Z_0-9+]+):\s+.*?\s+:\s+(?P[\d:x\s\->!]+)\s+:\s+(?P[a-z\d_\s\->x]+)\s+:\s+(?P[\d:x\s\->!]+?)\s+:\s+(?:(?P.*?)\s+:\s+)?(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?" ) trace_pattern = re.compile( @@ -66,7 +66,40 @@ def parse_log(file_path): for line in f: line_idx += 1 - op_match = op_pattern.search(line) + if "|" in line and "profile-op" in line: + parts = [p.strip() for p in line.split("|")] + prefix = parts[0] + prefix_match = re.search(r"profile-op\s+(?P[A-Z_0-9+]+)", prefix) + if not prefix_match: + continue + + if len(parts) == 7: + dims, types, strides, params, timings = parts[2], parts[3], parts[4], parts[5], parts[6] + elif len(parts) == 6: + dims, types, strides, params, timings = parts[2], parts[3], parts[4], "", parts[5] + else: + continue + + timing_match = re.search( + r"(?:op-)?usec\s+(?P\d+)\s+(?:op-)?cycles\s+(?P\d+)(?:\s+start\s+(?P\d+))?(?:\s+mhz\s+(?P[\d.]+))?(?:\s+pmu\s+\[(?P[\d,\s]+)\])?(?:\s+evt\s+\[(?P[\d,\s]+)\])?", + timings + ) + if not timing_match: + continue + + op_match = timing_match + op_name = prefix_match.group("op_name") + else: + op_match = op_pattern.search(line) + if op_match: + op_name = op_match.group('op_name') + dims = op_match.group('dims').strip() if op_match.group('dims') else '' + types = op_match.group('types').strip() if op_match.group('types') else '' + strides = op_match.group('strides').strip() if op_match.group('strides') else '' + params = op_match.group('params').strip() if ('params' in op_match.groupdict() and op_match.group('params')) else '' + else: + op_match = None + if op_match: cycles_start_raw = op_match.group('start') unwrapped_cycles_start = None @@ -77,10 +110,11 @@ def parse_log(file_path): op_text = line[idx + 11:].strip() if idx != -1 else line.strip() current_op = { - 'name': op_match.group('op_name'), - 'dims': op_match.group('dims').strip() if op_match.group('dims') else '', - 'types': op_match.group('types').strip() if op_match.group('types') else '', - 'strides': op_match.group('strides').strip() if op_match.group('strides') else '', + 'name': op_name, + 'dims': dims, + 'types': types, + 'strides': strides, + 'params': params, 'op_text': op_text, 'usec': int(op_match.group('usec')), 'cycles': int(op_match.group('cycles')), @@ -397,6 +431,8 @@ def generate_perfetto_trace(filtered_ops, output_path): debug_annots.append(make_debug_annotation("line", int_val=op['line_num'])) if 'strides' in op and op['strides']: debug_annots.append(make_debug_annotation("strides", string_val=op['strides'])) + if 'params' in op and op['params'] and op['params'] != '----': + debug_annots.append(make_debug_annotation("params", string_val=op['params'])) # Slice Begin evt_begin = make_track_event(1, 2, name=f"{op['name']} ({op['dims']})", category="operator", debug_annotations=debug_annots) diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 87d353ef45..499be5a585 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -3af5f5760e19a96427f5f7a93b79cbdf3d4b265b +707321c4cf6d21cb4bc831aa8b687dbf01a521ce diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 402d7bbad3..f913b0c7dc 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -5,7 +5,7 @@ import os import sys import subprocess -HTTPLIB_VERSION = "refs/tags/v0.47.0" +HTTPLIB_VERSION = "refs/tags/v0.48.0" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", diff --git a/scripts/ui-assets.cmake b/scripts/ui-assets.cmake index 78a0f4c844..349fa9bf81 100644 --- a/scripts/ui-assets.cmake +++ b/scripts/ui-assets.cmake @@ -20,6 +20,7 @@ set(LLAMA_UI_GZIP "" CACHE STRING "Apply gzip compress to assets to save ban set(DIST_DIR "${UI_BINARY_DIR}/dist") set(SRC_DIST_DIR "${UI_SOURCE_DIR}/dist") +set(WORK_DIR "${UI_BINARY_DIR}/ui-src") set(STAMP_FILE "${UI_BINARY_DIR}/.ui-stamp") set(UI_CPP "${UI_BINARY_DIR}/ui.cpp") set(UI_H "${UI_BINARY_DIR}/ui.h") @@ -64,6 +65,22 @@ function(npm_build_should_skip out_var) set(${out_var} TRUE PARENT_SCOPE) endfunction() +function(stage_sources) + if(EXISTS "${WORK_DIR}") + file(GLOB staged RELATIVE "${WORK_DIR}" "${WORK_DIR}/*") + list(REMOVE_ITEM staged "node_modules") + foreach(entry ${staged}) + file(REMOVE_RECURSE "${WORK_DIR}/${entry}") + endforeach() + endif() + + file(COPY "${UI_SOURCE_DIR}/" + DESTINATION "${WORK_DIR}" + NO_SOURCE_PERMISSIONS + PATTERN "node_modules" EXCLUDE + ) +endfunction() + function(npm_build out_var) set(${out_var} FALSE PARENT_SCOPE) @@ -89,14 +106,16 @@ function(npm_build out_var) return() endif() + stage_sources() + # npm writes node_modules/.package-lock.json on every successful install, # so a package-lock.json newer than this marker means node_modules is stale - set(NPM_MARKER "${UI_SOURCE_DIR}/node_modules/.package-lock.json") + set(NPM_MARKER "${WORK_DIR}/node_modules/.package-lock.json") set(need_install FALSE) if(NOT EXISTS "${NPM_MARKER}") set(need_install TRUE) else() - file(TIMESTAMP "${UI_SOURCE_DIR}/package-lock.json" lock_ts) + file(TIMESTAMP "${WORK_DIR}/package-lock.json" lock_ts) file(TIMESTAMP "${NPM_MARKER}" marker_ts) if(lock_ts STRGREATER marker_ts) set(need_install TRUE) @@ -107,7 +126,7 @@ function(npm_build out_var) message(STATUS "UI: running npm install") execute_process( COMMAND ${NPM_EXECUTABLE} install - WORKING_DIRECTORY "${UI_SOURCE_DIR}" + WORKING_DIRECTORY "${WORK_DIR}" RESULT_VARIABLE rc ERROR_VARIABLE err ) @@ -124,7 +143,7 @@ function(npm_build out_var) execute_process( COMMAND ${CMAKE_COMMAND} -E env "LLAMA_UI_OUT_DIR=${DIST_DIR}" "LLAMA_UI_VERSION=${HF_VERSION}" "LLAMA_BUILD_NUMBER=${LLAMA_BUILD_NUMBER}" ${NPM_EXECUTABLE} run build - WORKING_DIRECTORY "${UI_SOURCE_DIR}" + WORKING_DIRECTORY "${WORK_DIR}" RESULT_VARIABLE rc ERROR_VARIABLE err ) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 529bc4a5e9..220240ea95 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1156,6 +1156,10 @@ void llama_context::set_embeddings_layer_inp(uint32_t lid, bool enable) { sched_need_reserve = true; } +void llama_context::set_nextn_layer_offset(int32_t offset) { + cparams.nextn_layer_offset = offset; +} + void llama_context::set_causal_attn(bool value) { LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value); @@ -3699,6 +3703,10 @@ void llama_set_embeddings_layer_inp(llama_context * ctx, uint32_t lid, bool valu ctx->set_embeddings_layer_inp(lid, value); } +void llama_set_nextn_layer_offset(llama_context * ctx, int32_t offset) { + ctx->set_nextn_layer_offset(offset); +} + llama_memory_t llama_get_memory(const struct llama_context * ctx) { if (!ctx) { return nullptr; diff --git a/src/llama-context.h b/src/llama-context.h index 853052be2c..f8b7805871 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -115,6 +115,7 @@ struct llama_context { void set_embeddings (bool value); void set_embeddings_nextn(bool value, bool masked); void set_embeddings_layer_inp(uint32_t lid, bool enable); + void set_nextn_layer_offset(int32_t offset); void set_causal_attn(bool value); void set_warmup(bool value); diff --git a/src/llama-cparams.h b/src/llama-cparams.h index 2b109f909c..546ae1e2c1 100644 --- a/src/llama-cparams.h +++ b/src/llama-cparams.h @@ -18,6 +18,8 @@ struct llama_cparams { int32_t n_threads; // number of threads to use for generation int32_t n_threads_batch; // number of threads to use for batch processing + int32_t nextn_layer_offset = 0; + float rope_freq_base; float rope_freq_scale; diff --git a/src/llama-ext.h b/src/llama-ext.h index 8b5679b690..348bbae957 100644 --- a/src/llama-ext.h +++ b/src/llama-ext.h @@ -95,6 +95,11 @@ LLAMA_API llama_memory_breakdown llama_get_memory_breakdown(const struct llama_c // If masked == false, output the embeddings for all tokens in the batch regardless of batch.logits LLAMA_API void llama_set_embeddings_nextn(struct llama_context * ctx, bool value, bool masked); +// Select which appended NextN block the DECODER_MTP graph runs (offset past +// the trunk: il = n_layer() + offset). Used by the speculative NextN driver to +// chain multiple trained NextN heads. Default 0 (first head). +LLAMA_API void llama_set_nextn_layer_offset(struct llama_context * ctx, int32_t offset); + // mirrors: // LLAMA_API float * llama_get_embeddings(struct llama_context * ctx); LLAMA_API float * llama_get_embeddings_nextn(struct llama_context * ctx); diff --git a/src/llama-graph.h b/src/llama-graph.h index 5e8a658350..a6e8c3985b 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -682,9 +682,16 @@ struct llm_graph_params { } } + // TODO: https://github.com/ggml-org/llama.cpp/pull/24340#discussion_r3448035248 + if (cparams.nextn_layer_offset != other.cparams.nextn_layer_offset) { + return false; + } + return - cparams.embeddings == other.cparams.embeddings && - cparams.causal_attn == other.cparams.causal_attn && + cparams.embeddings == other.cparams.embeddings && + cparams.embeddings_nextn == other.cparams.embeddings_nextn && + cparams.embeddings_nextn_masked == other.cparams.embeddings_nextn_masked && + cparams.causal_attn == other.cparams.causal_attn && arch == other.arch && gtype == other.gtype && cvec == other.cvec && diff --git a/src/llama-model.cpp b/src/llama-model.cpp index c528755339..d041a9ce3e 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -2312,6 +2312,10 @@ int32_t llama_model_n_layer(const llama_model * model) { return model->hparams.n_layer(); } +int32_t llama_model_n_layer_nextn(const llama_model * model) { + return model->hparams.n_layer_nextn; +} + int32_t llama_model_n_head(const llama_model * model) { return model->hparams.n_head(); } diff --git a/src/llama-quant.cpp b/src/llama-quant.cpp index cf92ce4bb8..847e79f465 100644 --- a/src/llama-quant.cpp +++ b/src/llama-quant.cpp @@ -847,7 +847,7 @@ static void init_quantize_state_counters(quantize_state_impl & qs, std::vectordata[i].logit = -INFINITY; } } - - llama_sampler_softmax_impl(cur_p, true); } static struct llama_sampler * llama_sampler_top_n_sigma_clone(const struct llama_sampler * smpl) { diff --git a/src/models/glm-dsa.cpp b/src/models/glm-dsa.cpp index 11d91312de..32fe6def6f 100644 --- a/src/models/glm-dsa.cpp +++ b/src/models/glm-dsa.cpp @@ -101,11 +101,11 @@ void llama_model_glm_dsa::load_arch_tensors(llama_model_loader &) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); // DSA indexer - layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); - layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); - layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); - layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); - layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags | TENSOR_NOT_REQUIRED); if (i < (int) hparams.n_layer_dense_lead) { layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); diff --git a/src/models/lfm2.cpp b/src/models/lfm2.cpp index 97da8a6abb..07b7346ee4 100644 --- a/src/models/lfm2.cpp +++ b/src/models/lfm2.cpp @@ -190,7 +190,15 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); - bx = ggml_concat(ctx0, conv, bx, 0); + // causal prepends the state, non-causal pads symmetrically for a centered window + if (hparams.causal_attn) { + bx = ggml_concat(ctx0, conv, bx, 0); + } else { + const int64_t pad = (hparams.n_shortconv_l_cache - 1) / 2; + auto * left = ggml_cont(ctx0, + ggml_view_3d(ctx0, conv, pad, hparams.n_embd, n_seqs, conv->nb[1], conv->nb[2], (d_conv - pad) * conv->nb[0])); + bx = ggml_pad_ext(ctx0, ggml_concat(ctx0, left, bx, 0), 0, pad, 0, 0, 0, 0, 0, 0); + } GGML_ASSERT(bx->ne[0] > conv->ne[0]); // last d_conv columns is a new conv state @@ -266,10 +274,12 @@ llama_model_lfm2::graph::graph(const llama_model & model, const llm_graph_ cb(cur, "result_norm", -1); res->t_embd = cur; - cur = build_lora_mm(model.output, cur, model.output_s); - cb(cur, "result_output", -1); + if (!cparams.embeddings) { + cur = build_lora_mm(model.output, cur, model.output_s); + cb(cur, "result_output", -1); - res->t_logits = cur; + res->t_logits = cur; + } ggml_build_forward_expand(gf, cur); } diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 6783d98ec2..d8ffe43ae7 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -156,6 +156,8 @@ llama_model_qwen35::graph::graph(const llama_model & model, const llm_graph_para // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index eb5e9a406a..7b0876cbb0 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -179,6 +179,8 @@ llama_model_qwen35moe::graph::graph(const llama_model & model, const llm_graph_p // MTP/NextN layers are loaded as extra decoder blocks but not executed in the main pass. for (int il = 0; il < n_layer; ++il) { + res->t_layer_inp[il] = inpL; + ggml_tensor * inpSA = inpL; cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); diff --git a/src/models/step35.cpp b/src/models/step35.cpp index e2218c5870..9b7b18a367 100644 --- a/src/models/step35.cpp +++ b/src/models/step35.cpp @@ -112,7 +112,7 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); }; - auto load_block_mtp = [&](int i, bool is_first_mtp) { + auto load_block_mtp = [&](int i) { auto & layer = layers[i]; const uint32_t n_head_l = hparams.n_head(i); @@ -121,15 +121,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { // The MTP block is a full Step3p5 decoder layer (mtp_block) plus the // NextN-specific wiring (enorm/hnorm/eh_proj + optional shared head). - // `mtp_flags` becomes NOT_REQUIRED when the GGUF is trunk-only. - // - // Only the FIRST MTP block (i == n_main) is required for the - // single-block MTP runtime; trailing MTP blocks are always tolerated - // as missing so pruned GGUFs (block 0 only) load cleanly. Override - // mtp_flags to NOT_REQUIRED for those. - const int eff_mtp_flags = is_first_mtp ? mtp_flags : (mtp_flags | TENSOR_NOT_REQUIRED); + // Multi-block MTP: every declared MTP block is required (the draft chain + // runs all n_layer_nextn heads), so each block uses the captured + // `mtp_flags` directly โ€” already NOT_REQUIRED for a trunk-only GGUF, + // which keeps that path correct. - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, mtp_flags); layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, TENSOR_NOT_REQUIRED); @@ -140,12 +137,12 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot_max/2}, TENSOR_NOT_REQUIRED | TENSOR_DUPLICATED); } - create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, eff_mtp_flags); - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, eff_mtp_flags); + create_tensor_qkv(layer, i, n_embd, n_embd_head_k * n_head_l, n_embd_k_gqa, n_embd_v_gqa, mtp_flags); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_v * n_head_l, n_embd}, mtp_flags); layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), {n_embd, n_head_l}, TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, eff_mtp_flags); + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, mtp_flags); // dense MLP (leading dense blocks) โ€” present if the MTP block isn't MoE layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); @@ -165,9 +162,9 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), {hparams.n_ff_shexp, n_embd}, TENSOR_NOT_REQUIRED); // NextN-specific tensors that define the MTP block. - layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, eff_mtp_flags); - layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, eff_mtp_flags); - layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, eff_mtp_flags); + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, mtp_flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, mtp_flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, mtp_flags); layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, TENSOR_NOT_REQUIRED); @@ -176,13 +173,11 @@ void llama_model_step35::load_arch_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { load_block_trunk(i, trunk_flags); } - // Only the first MTP block (i == n_main) is required at runtime โ€” the - // single-block-MTP graph in build_arch_graph always uses that one. - // Trailing MTP blocks are loaded if present (so an un-pruned GGUF with - // all MTP layers still works) but tolerated when absent via the pruning - // path. See scripts/prune_step35_extra_mtp.py for the pruner. + // All n_layer_nextn MTP blocks are required โ€” the multi-block draft chain + // runs every head (head k at offset k). The GGUF declares the count via + // step35.nextn_predict_layers. for (int i = n_layer; i < n_layer_all; ++i) { - load_block_mtp(i, /*is_first_mtp=*/ i == n_layer); + load_block_mtp(i); } } @@ -372,13 +367,14 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr : llm_graph_context(params) { GGML_ASSERT(hparams.n_layer_nextn > 0 && "STEP35 MTP requires n_layer_nextn > 0"); - // Single-block MTP only: always run the first trained MTP block (Qwen - // MTP / vLLM single-MTP-layer style). Multi-block round-robin proved to - // be a much deeper refactor than this PR justifies; the trailing MTP - // blocks are loaded with TENSOR_NOT_REQUIRED so pruned GGUFs (with just - // block 0) also work โ€” see load_arch_tensors below and - // scripts/prune_step35_extra_mtp.py. - const int il = hparams.n_layer(); + // Multi-block MTP: the DECODER_MTP graph runs the MTP head selected by + // cparams.nextn_layer_offset (0 = first trained head). The speculative driver + // bumps the offset per draft step to chain heads 45->46->47. offset 0 keeps + // single-block behavior identical to before. + const int il = hparams.n_layer() + cparams.nextn_layer_offset; + GGML_ASSERT(cparams.nextn_layer_offset >= 0 && + cparams.nextn_layer_offset < (int) hparams.n_layer_nextn && + "nextn_layer_offset out of range [0, n_layer_nextn)"); const auto & layer = model.layers[il]; GGML_ASSERT(layer.nextn.eh_proj && "MTP block missing nextn.eh_proj"); @@ -536,6 +532,9 @@ llama_model_step35::graph_mtp::graph_mtp(const llama_model & model, const llm_gr cur = ggml_add(ctx0, cur, ffn_inp); cb(cur, "mtp_post_ffn", il); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + // Pre-norm hidden state: used by the AR draft loop to seed the next MTP step. cb(cur, "h_nextn", -1); res->t_h_nextn = cur; diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index e284a58d1c..0dd1d7b162 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -199,7 +199,6 @@ llama_build_and_test(test-jinja.cpp) llama_test(test-jinja NAME test-jinja-py ARGS -py LABEL python) llama_build_and_test(test-chat-auto-parser.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) llama_build_and_test(test-chat-template.cpp) -llama_build_and_test(test-json-partial.cpp) llama_build_and_test(test-log.cpp) llama_build_and_test( test-peg-parser.cpp diff --git a/tests/peg-parser/test-gbnf-generation.cpp b/tests/peg-parser/test-gbnf-generation.cpp index 00111e6a19..60066a817b 100644 --- a/tests/peg-parser/test-gbnf-generation.cpp +++ b/tests/peg-parser/test-gbnf-generation.cpp @@ -129,7 +129,154 @@ void test_gbnf_generation(testing &t) { }); assert_gbnf_equal(t, R"""( - root ::= ([^<] | "<" [^/] | "])* ("<" | "] until-0 + )""", gbnf); + }); + + t.test("until grammar overlapping delimiter", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.until("\n\n"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= until-0 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + until-0 ::= | [\n] until-0-01 | [^\n] until-0 + until-0-01 ::= | [\n] until-0-01 | [<] until-0-02 | [^\n<] until-0 + until-0-02 ::= | [\n] until-0-01 | [/] until-0-03 | [^\n/] until-0 + until-0-03 ::= | [\n] until-0-01 | [p] until-0-04 | [^\np] until-0 + until-0-04 ::= | [\n] until-0-01 | [a] until-0-05 | [^\na] until-0 + until-0-05 ::= | [\n] until-0-01 | [r] until-0-06 | [^\nr] until-0 + until-0-06 ::= | [\n] until-0-01 | [a] until-0-07 | [^\na] until-0 + until-0-07 ::= | [\n] until-0-01 | [m] until-0-08 | [^\nm] until-0 + until-0-08 ::= | [\n] until-0-01 | [e] until-0-09 | [^\ne] until-0 + until-0-09 ::= | [\n] until-0-01 | [t] until-0-10 | [^\nt] until-0 + until-0-10 ::= | [\n] until-0-01 | [e] until-0-11 | [^\ne] until-0 + until-0-11 ::= | [\n] until-0-01 | [r] until-0-12 | [^\nr] until-0 + until-0-12 ::= | [\n] until-0-01 | [>] until-0-13 | [^\n>] until-0 + until-0-13 ::= | [^\n] until-0 + )""", gbnf); + }); + + // DeepSeek-V3.2 tag prefix. The DSML token (๏ฝœDSML๏ฝœ) embeds U+FF5C, + // so the delimiter mixes ASCII and multi-byte codepoints. + t.test("until grammar unicode delimiter", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.until("<๏ฝœDSML๏ฝœ"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + root ::= until-0 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + until-0 ::= | [<] until-0-01 | [^<] until-0 + until-0-01 ::= | [<] until-0-01 | [\uFF5C] until-0-02 | [^<\uFF5C] until-0 + until-0-02 ::= | [<] until-0-01 | [D] until-0-03 | [^") + p.literal(""), ""); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + ac-3 ::= [<] ac-3-01 | [^<] ac-3 + ac-3-01 ::= [<] ac-3-01 | [/] ac-3-02 | [^/<] ac-3 + ac-3-02 ::= [<] ac-3-01 | [t] ac-3-03 | [^] | [<] ac-3-01 | [^<>] ac-3 + root ::= ac-3 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("ac grammar terminates at first delimiter", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.ac(p.until("\n\n") + p.literal("\n\n"), "\n\n"); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + ac-3 ::= [\n] ac-3-01 | [^\n] ac-3 + ac-3-01 ::= [\n] ac-3-01 | [<] ac-3-02 | [^\n<] ac-3 + ac-3-02 ::= [\n] ac-3-01 | [/] ac-3-03 | [^\n/] ac-3 + ac-3-03 ::= [\n] ac-3-01 | [p] ac-3-04 | [^\np] ac-3 + ac-3-04 ::= [\n] ac-3-01 | [a] ac-3-05 | [^\na] ac-3 + ac-3-05 ::= [\n] ac-3-01 | [r] ac-3-06 | [^\nr] ac-3 + ac-3-06 ::= [\n] ac-3-01 | [a] ac-3-07 | [^\na] ac-3 + ac-3-07 ::= [\n] ac-3-01 | [m] ac-3-08 | [^\nm] ac-3 + ac-3-08 ::= [\n] ac-3-01 | [e] ac-3-09 | [^\ne] ac-3 + ac-3-09 ::= [\n] ac-3-01 | [t] ac-3-10 | [^\nt] ac-3 + ac-3-10 ::= [\n] ac-3-01 | [e] ac-3-11 | [^\ne] ac-3 + ac-3-11 ::= [\n] ac-3-01 | [r] ac-3-12 | [^\nr] ac-3 + ac-3-12 ::= [\n] ac-3-01 | [>] ac-3-13 | [^\n>] ac-3 + ac-3-13 ::= [\n] | [^\n] ac-3 + root ::= ac-3 + space ::= | " " | "\n"{1,2} [ \t]{0,20} + )""", gbnf); + }); + + t.test("ac grammar multiple delimiters", [](testing &t) { + auto parser = build_peg_parser([](common_peg_parser_builder & p) { + return p.ac(p.eps(), std::vector{"ab", "cd", "ef"}); + }); + + auto gbnf = build_grammar([&](const common_grammar_builder & builder) { + parser.build_grammar(builder); + }); + + assert_gbnf_equal(t, R"""( + ac-1 ::= [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^ace] ac-1 + ac-1-01 ::= [b] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^abce] ac-1 + ac-1-03 ::= [d] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acde] ac-1 + ac-1-05 ::= [f] | [a] ac-1-01 | [c] ac-1-03 | [e] ac-1-05 | [^acef] ac-1 + root ::= ac-1 space ::= | " " | "\n"{1,2} [ \t]{0,20} )""", gbnf); }); diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 0dd8422e73..e83ee85dd4 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -10,7 +10,7 @@ #undef NDEBUG #include -int main(void) { +static void test(void) { common_params params; printf("test-arg-parser: make sure there is no duplicated arguments in any examples\n\n"); @@ -210,3 +210,13 @@ int main(void) { printf("test-arg-parser: all tests OK\n\n"); } + +int main(void) { + try { + test(); + } catch (std::exception & e) { + fprintf(stderr, "test-arg-parser: exception: %s\n", e.what()); + return 1; + } + return 0; +} diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 15ae38927c..3f18dbe220 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3298,21 +3298,29 @@ struct test_norm : public test_case { const std::array ne; const bool v; // whether a is a non-contiguous view const float eps; + const bool noncontig_rows; std::string vars() override { - return VARS_TO_STR4(type, ne, v, eps); + return VARS_TO_STR5(type, ne, v, eps, noncontig_rows); } test_norm(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 5, 4, 3}, bool v = false, - float eps = 1e-6f) - : type(type), ne(ne), v(v), eps(eps) {} + float eps = 1e-6f, + bool noncontig_rows = false) + : type(type), ne(ne), v(v), eps(eps), noncontig_rows(noncontig_rows) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + const std::array ne_a = noncontig_rows ? + std::array{ ne[1], ne[0], ne[2], ne[3] } : ne; + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_set_name(a, "a"); + if (noncontig_rows) { + a = ggml_permute(ctx, a, 1, 0, 2, 3); + ggml_set_name(a, "permuted a"); + } if (v) { a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(a, "view of a"); @@ -6193,21 +6201,29 @@ struct test_l2_norm : public test_case { const std::array ne; const float eps; bool v; + bool noncontig_rows; std::string vars() override { - return VARS_TO_STR4(type, ne, eps, v); + return VARS_TO_STR5(type, ne, eps, v, noncontig_rows); } test_l2_norm(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 64, 320, 1}, float eps = 1e-12f, - bool v = false) - : type(type), ne(ne), eps(eps), v(v) {} + bool v = false, + bool noncontig_rows = false) + : type(type), ne(ne), eps(eps), v(v), noncontig_rows(noncontig_rows) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + const std::array ne_a = noncontig_rows ? + std::array{ ne[1], ne[0], ne[2], ne[3] } : ne; + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_set_name(a, "a"); + if (noncontig_rows) { + a = ggml_permute(ctx, a, 1, 0, 2, 3); + ggml_set_name(a, "permuted a"); + } if (v) { a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(a, "view of a"); @@ -8282,9 +8298,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps)); test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps)); } + test_cases.emplace_back(new test_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, false, eps, true)); test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps)); test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false)); test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true)); + test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false, true)); } } @@ -8402,6 +8420,11 @@ static std::vector> make_test_cases_eval() { } } + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_MXFP4, GGML_TYPE_F32, 2880, 32, 2880, {1, 1}, {1, 1})); + + #if 0 { // Test paths in OpenCL @@ -8433,6 +8456,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1})); @@ -8449,6 +8473,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); @@ -8574,6 +8599,7 @@ static std::vector> make_test_cases_eval() { // gpt-oss issue with Vulkan mmq_id test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_MXFP4, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_Q4_0, GGML_TYPE_F32, 32, 2, false, 2880, 32, 2880)); for (ggml_type type_a : all_types) { test_cases.emplace_back(new test_mul_mat_id(type_a, GGML_TYPE_F32, 4, 2, false, 64, 16, 3*ggml_blck_size(type_a))); @@ -9270,6 +9296,34 @@ static std::vector> make_test_cases_perf() { } } + struct conv3d_perf_case { + int N, IC, ID, IH, IW, OC, KD, KH, KW, s0, s1, s2, p0, p1, p2, d0, d1, d2; + }; + + const std::vector conv3d_cases = { + {1, 320, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 1280, 8, 38, 26, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 320, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 1280, 8, 76, 52, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 320, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, +#if 0 + // too slow on some devices + {1, 1280, 8, 152, 104, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 320, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, + {1, 640, 4, 304, 208, 1280, 3, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1}, +#endif + }; + + for (ggml_type kernel_type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (const conv3d_perf_case & c : conv3d_cases) { + test_cases.emplace_back(new test_conv_3d( + c.N, c.IC, c.ID, c.IH, c.IW, + c.OC, c.KD, c.KH, c.KW, + c.s0, c.s1, c.s2, c.p0, c.p1, c.p2, c.d0, c.d1, c.d2, + kernel_type)); + } + } + test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 1, 1, 1})); test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {4096, 1, 1, 1}, {1, 512, 1, 1})); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 902a4c135a..c38aed8cfe 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1562,37 +1562,112 @@ static void test_msgs_oaicompat_json_conversion() { } } -static void test_split_by_role() { +static void test_msg_token_delimiters_split() { LOG_DBG("%s\n", __func__); + // Delimiters that share a leading token, distinguished by the second token, + // to exercise the per-position token matching. + const common_chat_msg_delimiters delims = { + { { COMMON_CHAT_ROLE_USER, "", { 10, 11 } }, + { COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } } + }; + // Empty inputs - assert_equals(0, common_chat_split_by_role("", {}).size()); - assert_equals(0, common_chat_split_by_role("hello", {}).size()); - assert_equals(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size()); + assert_equals(0, common_chat_msg_delimiters{}.split({}).spans.size()); + assert_equals(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size()); + assert_equals(0, delims.split({}).spans.size()); - // Multi-role conversation, no leading/trailing content + // No delimiters match -> no spans + assert_equals(0, delims.split({ 100, 101, 102 }).spans.size()); + + // Multi-role conversation: HiHelloBye { - const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye"; - const auto splits = common_chat_split_by_role(prompt, { - { "user", "<|user|>" }, - { "assistant", "<|assistant|>" }, - }); - assert_equals(3, splits.size()); + const llama_tokens tokens = { + 10, 11, // + 100, 101, // Hi + 10, 12, // + 200, 201, 202, // Hello + 10, 11, // + 300, 301, // Bye + }; - assert_equals("user", splits[0].role); - assert_equals(0, splits[0].pos); - assert_equals(10, splits[0].len); - assert_equals("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len)); + const auto result = delims.split(tokens); + const auto & spans = result.spans; + assert_equals(3, spans.size()); - assert_equals("assistant", splits[1].role); - assert_equals(10, splits[1].pos); - assert_equals(18, splits[1].len); - assert_equals("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len)); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(4, spans[0].len); - assert_equals("user", splits[2].role); - assert_equals(28, splits[2].pos); - assert_equals(11, splits[2].len); - assert_equals("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len)); + assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role); + assert_equals(4, spans[1].pos); + assert_equals(5, spans[1].len); + + assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role); + assert_equals(9, spans[2].pos); + assert_equals(4, spans[2].len); + + // is_user_start() is true at the token position where a user span begins + assert_equals(true, result.is_user_start(0)); + assert_equals(false, result.is_user_start(4)); // assistant span + assert_equals(true, result.is_user_start(9)); + } + + // Content before the first delimiter is not captured as a span + { + const llama_tokens tokens = { + 500, 501, // leading content (dropped) + 10, 11, // + 100, // Hi + }; + + const auto spans = delims.split(tokens).spans; + assert_equals(1, spans.size()); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(2, spans[0].pos); + assert_equals(3, spans[0].len); + } + + // Skipped regions (media chunks) are jumped over but still count as span content + { + const llama_tokens tokens = { + 10, 11, // + LLAMA_TOKEN_NULL, // media chunk (3 tokens) + LLAMA_TOKEN_NULL, + LLAMA_TOKEN_NULL, + 100, // Hi + 10, 12, // + }; + + const std::map skips = { { 2, 3 } }; + + const auto spans = delims.split(tokens, skips).spans; + assert_equals(2, spans.size()); + + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(6, spans[0].len); + + assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role); + assert_equals(6, spans[1].pos); + assert_equals(2, spans[1].len); + } + + // A delimiter sequence inside a skipped region is not matched + { + const llama_tokens tokens = { + 10, 11, // + 10, 12, // skipped region that happens to contain delimiter tokens + 100, // Hi + }; + + const std::map skips = { { 2, 2 } }; + + const auto spans = delims.split(tokens, skips).spans; + assert_equals(1, spans.size()); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(5, spans[0].len); } } @@ -5022,14 +5097,14 @@ static void test_template_output_peg_parsers(bool detailed_debug) { tst.test("Hello, world!\nWhat's up?").tools({ special_function_tool }).expect(message_assist).expect_reconstruction().run(); tst.test( - "```json\n\"42\" \n```") + "```json\n\"42\"\n```") .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .json_schema(const_schema) .expect_content(R"("42")") .run(); tst.test( - "\"42\" \n") + "\"42\"\n") .reasoning_format(COMMON_REASONING_FORMAT_AUTO) .json_schema(const_schema) .expect_content(R"("42")") @@ -5857,7 +5932,7 @@ int main(int argc, char ** argv) { { test_msg_diffs_compute(); test_msgs_oaicompat_json_conversion(); - test_split_by_role(); + test_msg_token_delimiters_split(); test_tools_oaicompat_json_conversion(); test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index 8039956246..81bbcd55a4 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -995,6 +995,32 @@ static void test_macros(testing & t) { json::object(), "Hello, John Smith,Hi, Jane Doe" ); + + test_template(t, "macro with caller", + "\ +{%- macro nest_dict(o, i, ff='') %}\n\ + {{- caller(ff) }}\n\ + {%- for k, v in o|items %}\n\ + {{- i + k + ': ' }}\n\ + {%- if v is mapping %}\n\ + {{- '{' }}\n\ + {% call(f) nest_dict(v, i + ' ') %}\n\ + {{- 'fail' if ff is undefined }}\n\ + {%- endcall %}\n\ + {{- i + '}' }}\n\ + {% else %}\n\ + {{- v|string }}\n\ + {% endif %}\n\ + {%- endfor %}\n\ +{%- endmacro %}\n\ +{%- call(f) nest_dict({'root1': 1, 'root2': {'nest1': 1, 'nest2': {'nest3': 2}}}, ' ', 'Dict') %}\n\ + {{- 'fail' if ff is defined }}\n\ + {{- f + ' {' }}\n\ +{% endcall %}\n\ +{{- '}' }}", + json::object(), + "Dict {\n root1: 1\n root2: {\n nest1: 1\n nest2: {\n nest3: 2\n }\n }\n}" + ); } static void test_namespace(testing & t) { diff --git a/tests/test-json-partial.cpp b/tests/test-json-partial.cpp deleted file mode 100644 index 39da9276ef..0000000000 --- a/tests/test-json-partial.cpp +++ /dev/null @@ -1,287 +0,0 @@ -#include "common.h" -#include "json-partial.h" -#include -#include -#include - -template static void assert_equals(const T & expected, const T & actual) { - if (expected != actual) { - std::cerr << "Expected: " << expected << std::endl; - std::cerr << "Actual: " << actual << std::endl; - std::cerr << std::flush; - throw std::runtime_error("Test failed"); - } -} - -static void test_json_healing() { - auto parse = [](const std::string & str) { - std::cerr << "# Parsing: " << str << '\n'; - std::string::const_iterator it = str.begin(); - const auto end = str.end(); - common_json out; - std::string healing_marker = "$llama.cpp.json$"; - if (common_json_parse(it, end, healing_marker, out)) { - auto dump = out.json.dump(); - std::cerr << "Parsed: " << dump << '\n'; - std::cerr << "Magic: " << out.healing_marker.json_dump_marker << '\n'; - std::string result; - if (!out.healing_marker.json_dump_marker.empty()) { - auto i = dump.find(out.healing_marker.json_dump_marker); - if (i == std::string::npos) { - throw std::runtime_error("Failed to find magic in dump " + dump + " (magic: " + out.healing_marker.json_dump_marker + ")"); - } - result = dump.substr(0, i); - } else { - result = dump; - } - std::cerr << "Result: " << result << '\n'; - if (string_starts_with(str, result)) { - std::cerr << "Failure!\n"; - } - // return dump; - } else { - throw std::runtime_error("Failed to parse: " + str); - } - - }; - auto parse_all = [&](const std::string & str) { - for (size_t i = 1; i < str.size(); i++) { - parse(str.substr(0, i)); - } - }; - parse_all("{\"a\": \"b\"}"); - parse_all("{\"hey\": 1, \"ho\\\"ha\": [1]}"); - - parse_all("[{\"a\": \"b\"}]"); - - auto test = [&](const std::vector & inputs, const std::string & expected, const std::string & expected_marker) { - for (const auto & input : inputs) { - common_json out; - assert_equals(true, common_json_parse(input, "$foo", out)); - assert_equals(expected, out.json.dump(/* indent */ -1, /* indent_char */ ' ', /* ensure_ascii */ true)); - assert_equals(expected_marker, out.healing_marker.json_dump_marker); - } - }; - // No healing needed: - test( - { - R"([{"a":"b"}, "y"])", - }, - R"([{"a":"b"},"y"])", - "" - ); - // Partial literals can't be healed: - test( - { - R"([1)", - R"([tru)", - R"([n)", - R"([nul)", - R"([23.2)", - }, - R"(["$foo"])", - R"("$foo)" - ); - test( - { - R"({"a": 1)", - R"({"a": tru)", - R"({"a": n)", - R"({"a": nul)", - R"({"a": 23.2)", - }, - R"({"a":"$foo"})", - R"("$foo)" - ); - test( - { - R"({)", - }, - R"({"$foo":1})", - R"("$foo)" - ); - test( - { - R"([)", - }, - R"(["$foo"])", - R"("$foo)" - ); - // Healing right after a full literal - test( - { - R"(1 )", - }, - R"(1)", - "" - ); - test( - { - R"(true)", - R"(true )", - }, - R"(true)", - "" - ); - test( - { - R"(null)", - R"(null )", - }, - R"(null)", - "" - ); - test( - { - R"([1 )", - }, - R"([1,"$foo"])", - R"(,"$foo)" - ); - test( - { - R"([{})", - R"([{} )", - }, - R"([{},"$foo"])", - R"(,"$foo)" - ); - test( - { - R"([true)", - }, - // TODO: detect the true/false/null literal was complete - R"(["$foo"])", - R"("$foo)" - ); - test( - { - R"([true )", - }, - R"([true,"$foo"])", - R"(,"$foo)" - ); - test( - { - R"([true,)", - }, - R"([true,"$foo"])", - R"("$foo)" - ); - // Test nesting - test( - { - R"([{"a": [{"b": [{)", - }, - R"([{"a":[{"b":[{"$foo":1}]}]}])", - R"("$foo)" - ); - test( - { - R"([{"a": [{"b": [)", - }, - R"([{"a":[{"b":["$foo"]}]}])", - R"("$foo)" - ); - - test( - { - R"([{"a": "b"})", - R"([{"a": "b"} )", - }, - R"([{"a":"b"},"$foo"])", - R"(,"$foo)" - ); - test( - { - R"([{"a": "b"},)", - R"([{"a": "b"}, )", - }, - R"([{"a":"b"},"$foo"])", - R"("$foo)" - ); - test( - { - R"({ "code)", - }, - R"({"code$foo":1})", - R"($foo)" - ); - test( - { - R"({ "code\)", - }, - R"({"code\\$foo":1})", - R"(\$foo)" - ); - test( - { - R"({ "code")", - }, - R"({"code":"$foo"})", - R"(:"$foo)" - ); - test( - { - R"({ "key")", - }, - R"({"key":"$foo"})", - R"(:"$foo)" - ); - // Test unicode escape sequences - test( - { - R"({"a":"\u)", - }, - R"({"a":"\u0000$foo"})", - R"(0000$foo)" - ); - test( - { - R"({"a":"\u00)", - }, - R"({"a":"\u0000$foo"})", - R"(00$foo)" - ); - test( - { - R"({"a":"\ud300)", - }, - R"({"a":"\ud300$foo"})", - R"($foo)" - ); - test( - { - R"({"a":"\ud800)", - }, - R"({"a":"\ud800\udc00$foo"})", - R"(\udc00$foo)" - ); - test( - { - R"({"a":"\ud800\)", - }, - R"({"a":"\ud800\udc00$foo"})", - R"(udc00$foo)" - ); - test( - { - R"({"a":"\ud800\u)", - }, - R"({"a":"\ud800\udc00$foo"})", - R"(dc00$foo)" - ); - test( - { - R"({"a":"\ud800\udc00)", - }, - R"({"a":"\ud800\udc00$foo"})", - R"($foo)" - ); -} - -int main() { - test_json_healing(); - std::cerr << "All tests passed.\n"; - return 0; -} diff --git a/tests/test-json-schema-to-grammar.cpp b/tests/test-json-schema-to-grammar.cpp index b4362852c3..f095274cd1 100755 --- a/tests/test-json-schema-to-grammar.cpp +++ b/tests/test-json-schema-to-grammar.cpp @@ -92,7 +92,7 @@ static void test_all(const std::string & lang, std::function(env: LLAMA_ARG_MMPROJ_URL) | | `--mmproj-auto, --no-mmproj, --no-mmproj-auto` | whether to use multimodal projector file (if available), useful when using -hf (default: enabled)
(env: LLAMA_ARG_MMPROJ_AUTO) | | `--mmproj-offload, --no-mmproj-offload` | whether to enable GPU offloading for multimodal projector (default: enabled)
(env: LLAMA_ARG_MMPROJ_OFFLOAD) | -| `--image, --audio FILE` | path to an image or audio file. use with multimodal models, use comma-separated values for multiple files | +| `--image, --audio, --video FILE` | path to an image, audio, or video file. use with multimodal models, use comma-separated values for multiple files | | `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MIN_TOKENS) | | `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MAX_TOKENS) | | `--chat-template-kwargs STRING` | sets additional params for the json template parser, must be a valid json object string, e.g. '{"key1":"value1","key2":"value2"}'
(env: LLAMA_ARG_CHAT_TEMPLATE_KWARGS) | @@ -174,6 +174,7 @@ | `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)
if suffix/prefix are specified, template will be disabled
only commonly used templates are accepted (unless --jinja is set before this flag):
list of built-in templates:
bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek-ocr, deepseek2, deepseek3, exaone-moe, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, granite-4.0, granite-4.1, grok-2, hunyuan-dense, hunyuan-moe, hunyuan-vl, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, solar-open, vicuna, vicuna-orca, yandex, zephyr
(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) | | `--skip-chat-parsing, --no-skip-chat-parsing` | force a pure content parser, even if a Jinja template is specified; model will output everything in the content section, including any reasoning and/or tool calls (default: disabled)
(env: LLAMA_ARG_SKIP_CHAT_PARSING) | | `--simple-io` | use basic IO for better compatibility in subprocesses and limited consoles | +| `--log-prompts-dir PATH` | Log prompts to directory (only used for debugging, default: disabled) | | `--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft /[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_HF_REPO) | | `--spec-draft-threads, -td, --threads-draft N` | number of threads to use during generation (default: same as --threads) | | `--spec-draft-threads-batch, -tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) | diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index c03894b4b1..8b7b58693f 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -202,7 +202,7 @@ struct cli_context { // TODO: support remote files in the future (http, https, etc) std::string load_input_file(const std::string & fname, bool is_media) { - std::ifstream file(fname, std::ios::binary); + std::ifstream file = fs_open_ifstream(fname, std::ios::binary); if (!file) { return ""; } diff --git a/tools/llama-bench/llama-bench.cpp b/tools/llama-bench/llama-bench.cpp index 55970c0745..2695f58785 100644 --- a/tools/llama-bench/llama-bench.cpp +++ b/tools/llama-bench/llama-bench.cpp @@ -1035,25 +1035,23 @@ static cmd_params parse_cmd_params(int argc, char ** argv) { if (!params.hf_repo.empty()) { for (size_t i = 0; i < params.hf_repo.size(); i++) { - common_params_model model; - - if (params.hf_file.empty() || params.hf_file[i].empty()) { - model.hf_repo = params.hf_repo[i]; - } else { - model.hf_repo = params.hf_repo[i]; - model.hf_file = params.hf_file[i]; + common_params p; + p.hf_token = params.hf_token; + p.offline = params.offline; + p.model.hf_repo = params.hf_repo[i]; + if (!params.hf_file.empty() && !params.hf_file[i].empty()) { + p.model.hf_file = params.hf_file[i]; } - common_download_opts opts; - opts.bearer_token = params.hf_token; - opts.offline = params.offline; - auto download_result = common_download_model(model, opts); - if (download_result.model_path.empty()) { + // only the text model file is needed + common_models_handler models_handler = common_models_handler_init(p, LLAMA_EXAMPLE_BENCH); + common_models_handler_apply(models_handler, p); + if (p.model.path.empty()) { fprintf(stderr, "error: failed to download model from HuggingFace\n"); exit(1); } - params.model.push_back(download_result.model_path); + params.model.push_back(p.model.path); } } diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 09b62357f3..ea684d9f15 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -115,22 +115,28 @@ if (TARGET mtmd) endif() endif() -add_executable(llama-llava-cli deprecation-warning.cpp) -add_executable(llama-gemma3-cli deprecation-warning.cpp) -add_executable(llama-minicpmv-cli deprecation-warning.cpp) -add_executable(llama-qwen2vl-cli deprecation-warning.cpp) +# Gate CLI binaries on LLAMA_BUILD_TOOLS so that standalone library-only +# builds (LLAMA_BUILD_MTMD=ON with LLAMA_BUILD_TOOLS=OFF โ€” e.g. Apple +# XCFramework packaging) skip the executables entirely. LLAMA_BUILD_COMMON +# defaults to ON in standalone builds, so we cannot rely on it for gating. +if (LLAMA_BUILD_TOOLS) + add_executable(llama-llava-cli deprecation-warning.cpp) + add_executable(llama-gemma3-cli deprecation-warning.cpp) + add_executable(llama-minicpmv-cli deprecation-warning.cpp) + add_executable(llama-qwen2vl-cli deprecation-warning.cpp) -set(TARGET llama-mtmd-cli) -add_executable (${TARGET} mtmd-cli.cpp) -set_target_properties (${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli) -if(LLAMA_TOOLS_INSTALL) - install(TARGETS ${TARGET} RUNTIME) + set(TARGET llama-mtmd-cli) + add_executable (${TARGET} mtmd-cli.cpp) + set_target_properties (${TARGET} PROPERTIES OUTPUT_NAME llama-mtmd-cli) + if(LLAMA_TOOLS_INSTALL) + install(TARGETS ${TARGET} RUNTIME) + endif() + target_link_libraries (${TARGET} PRIVATE llama-common mtmd Threads::Threads) + target_compile_features(${TARGET} PRIVATE cxx_std_17) + + # mtmd-debug tool + add_executable(llama-mtmd-debug debug/mtmd-debug.cpp) + set_target_properties(llama-mtmd-debug PROPERTIES OUTPUT_NAME llama-mtmd-debug) + target_link_libraries(llama-mtmd-debug PRIVATE llama-common mtmd Threads::Threads) + target_compile_features(llama-mtmd-debug PRIVATE cxx_std_17) endif() -target_link_libraries (${TARGET} PRIVATE llama-common mtmd Threads::Threads) -target_compile_features(${TARGET} PRIVATE cxx_std_17) - -# mtmd-debug tool -add_executable(llama-mtmd-debug debug/mtmd-debug.cpp) -set_target_properties(llama-mtmd-debug PROPERTIES OUTPUT_NAME llama-mtmd-debug) -target_link_libraries(llama-mtmd-debug PRIVATE llama-common mtmd Threads::Threads) -target_compile_features(llama-mtmd-debug PRIVATE cxx_std_17) diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index f232b68e5a..5b413681f0 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -13,6 +13,14 @@ #include #include #include +#include + +#ifdef _WIN32 +#ifndef NOMINMAX +#define NOMINMAX +#endif +#include +#endif // Internal header for clip.cpp @@ -34,6 +42,7 @@ #define KEY_N_HEAD "clip.%s.attention.head_count" #define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv" #define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" +#define KEY_FEATURE_LAYERS "clip.%s.feature_layer" // vision-specific #define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities @@ -46,7 +55,6 @@ #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" -#define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side" #define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side" @@ -661,6 +669,22 @@ struct clip_image_f32_batch { // common utils // +#ifdef _WIN32 +static std::ifstream open_ifstream_binary(const std::string & fname) { + int wlen = MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, NULL, 0); + if (!wlen) { + throw std::runtime_error("failed to convert filename to UTF-16: " + fname); + } + std::vector wfname(wlen); + (void)MultiByteToWideChar(CP_UTF8, 0, fname.c_str(), -1, wfname.data(), wlen); + return std::ifstream(wfname.data(), std::ios::binary); +} +#else +static std::ifstream open_ifstream_binary(const std::string & fname) { + return std::ifstream(fname, std::ios::binary); +} +#endif + static std::string string_format(const char * fmt, ...) { va_list ap; va_list ap2; diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 08ed0b3412..43fbcc1d5a 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -92,7 +92,7 @@ struct clip_hparams { float eps = 1e-6; float rope_theta = 0.0; - std::vector vision_feature_layer; + std::vector feature_layers; int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) @@ -166,8 +166,8 @@ struct clip_hparams { return false; } - bool is_vision_feature_layer(int32_t layer) const { - return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end(); + bool is_feature_layer(int32_t layer) const { + return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end(); } }; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 8e2dc64415..4d008f1e6b 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -534,7 +534,7 @@ ggml_tensor * clip_graph::build_vit( ggml_tensor * clip_graph::build_inp() { ggml_tensor * inp_raw = build_inp_raw(); ggml_tensor * inp = ggml_conv_2d(ctx0, model.patch_embeddings_0, inp_raw, patch_size, patch_size, 0, 0, 1, 1); - inp = ggml_reshape_2d(ctx0, inp, n_patches, n_embd); + inp = ggml_reshape_3d(ctx0, inp, n_patches, n_embd, n_batch); inp = ggml_cont(ctx0, ggml_transpose(ctx0, inp)); if (model.patch_bias) { inp = ggml_add(ctx0, inp, model.patch_bias); @@ -1047,8 +1047,17 @@ struct clip_model_loader { bool has_vision = false; bool has_audio = false; + mtmd_progress_callback progress_callback = nullptr; + void * progress_callback_user_data = nullptr; + // TODO @ngxson : we should not pass clip_ctx here, it should be clip_model - clip_model_loader(const char * fname, bool skip_tensors = false) : fname(fname) { + clip_model_loader(const char * fname, + bool skip_tensors = false, + mtmd_progress_callback progress_cb = nullptr, + void * progress_user_data = nullptr) + : fname(fname), + progress_callback(progress_cb), + progress_callback_user_data(progress_user_data) { struct ggml_context * meta = nullptr; struct gguf_init_params params = { @@ -1257,12 +1266,10 @@ struct clip_model_loader { } } - // Load the vision feature layer indices if they are explicitly provided; - // if multiple vision feature layers are present, the values will be concatenated - // to form the final visual features. + // Load the vision/audio feature layer indices if they are explicitly provided // NOTE: gguf conversions should standardize the values of the vision feature layer to // be non-negative, since we use -1 to mark values as unset here. - get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false); + get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false); // model-specific params switch (model.proj_type) { @@ -1653,6 +1660,7 @@ struct clip_model_loader { get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size); get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate); get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count); + // NOTE: feature layers loaded above in common path } break; case PROJECTOR_TYPE_JANUS_PRO: { @@ -1665,11 +1673,11 @@ struct clip_model_loader { hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW; hparams.image_resize_pad = PAD_CEIL; - get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer); + // NOTE: feature_layers loaded in common path as optional get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets); - if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) { - throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d", - hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size())); + if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) { + throw std::runtime_error(string_format("%s: feature_layers.size() %d != proj_spatial_offsets.size() %d", + hparams.feature_layers.size(), hparams.proj_spatial_offsets.size())); } get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side); @@ -1686,6 +1694,9 @@ struct clip_model_loader { // note: some models having hparams.image_size == 0, which means the image size is dynamic throw std::runtime_error(string_format("%s: image_size (%d) cannot be negative\n", __func__, hparams.image_size)); } + if (hparams.image_size > 65536) { + throw std::runtime_error(string_format("%s: image_size (%d) is too large (max 65536)\n", __func__, hparams.image_size)); + } if (hparams.patch_size <= 0) { throw std::runtime_error(string_format("%s: patch_size (%d) must be greater than 0\n", __func__, hparams.patch_size)); } @@ -1734,6 +1745,19 @@ struct clip_model_loader { LOG_INF("%s: audio_n_fft: %d\n", __func__, hparams.audio_n_fft); LOG_INF("%s: audio_window_len: %d\n", __func__, hparams.audio_window_len); LOG_INF("%s: audio_hop_len: %d\n", __func__, hparams.audio_hop_len); + + // GEMMA4UA is encoder-free: it uses n_mel_bins as a raw-waveform frame size (640) and has no FFT/filterbank, so the mel-range and FFT + // checks below do not apply to it. + const bool fft_based = model.proj_type != PROJECTOR_TYPE_GEMMA4UA; + + // Validate audio hparams loaded from GGUF metadata + if (hparams.n_mel_bins <= 0 || (fft_based && hparams.n_mel_bins > 256)) { + throw std::runtime_error(string_format("%s: n_mel_bins (%d) must be in range [1, 256]\n", __func__, hparams.n_mel_bins)); + } + if (fft_based && (hparams.audio_sample_rate <= 0 || hparams.audio_n_fft <= 0 || hparams.audio_hop_len <= 0 || hparams.audio_window_len <= 0)) { + throw std::runtime_error(string_format("%s: audio hparams invalid: sample_rate=%d n_fft=%d window_len=%d hop_len=%d\n", + __func__, hparams.audio_sample_rate, hparams.audio_n_fft, hparams.audio_window_len, hparams.audio_hop_len)); + } } LOG_INF("\n"); LOG_INF("%s: model size: %.2f MiB\n", __func__, model_size / 1024.0 / 1024.0); @@ -1747,7 +1771,7 @@ struct clip_model_loader { std::map tensor_offset; std::vector tensors_to_load; - auto fin = std::ifstream(fname, std::ios::binary); + auto fin = open_ifstream_binary(fname); if (!fin) { throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); } @@ -2726,7 +2750,7 @@ struct clip_model_loader { model.image_newline = get_tensor(TN_IMAGE_NEWLINE); // Load separate layerwise and spatial projector tensors - const auto projector_count = hparams.vision_feature_layer.size(); + const auto projector_count = hparams.feature_layers.size(); model.qf_proj_blocks.resize(projector_count); for (size_t bid = 0; bid < projector_count; ++bid) { auto & b = model.qf_proj_blocks[bid]; @@ -2782,37 +2806,60 @@ struct clip_model_loader { } // load data - if (!ctx_clip.no_alloc) { + { std::vector read_buf; + // start loading event + if (progress_callback){ + progress_callback(0.0, progress_callback_user_data); + } + + // compute total tensor data size for progress reporting + size_t total_data_size = 0; + for (auto & t : tensors_to_load) { + total_data_size += ggml_nbytes(t); + } + // alloc memory and offload data ggml_backend_buffer_type_t buft = ggml_backend_get_default_buffer_type(ctx_clip.backend); ctx_clip.buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(ctx_clip.ctx_data.get(), buft)); ggml_backend_buffer_set_usage(ctx_clip.buf.get(), GGML_BACKEND_BUFFER_USAGE_WEIGHTS); - for (auto & t : tensors_to_load) { - ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); - GGML_ASSERT(cur && "tensor not found in ctx_data"); - auto it_off = tensor_offset.find(t->name); - GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor"); - const size_t offset = it_off->second; - fin.seekg(offset, std::ios::beg); - if (!fin) { - throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name)); - } - size_t num_bytes = ggml_nbytes(cur); - if (ggml_backend_buft_is_host(buft)) { - // for the CPU and Metal backend, we can read directly into the tensor - fin.read(reinterpret_cast(cur->data), num_bytes); - } else { - // read into a temporary buffer first, then copy to device memory - read_buf.resize(num_bytes); - fin.read(reinterpret_cast(read_buf.data()), num_bytes); - ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); + // read the weight from file + if (!ctx_clip.no_alloc) { + size_t data_loaded = 0; + for (auto & t : tensors_to_load) { + ggml_tensor * cur = ggml_get_tensor(ctx_clip.ctx_data.get(), t->name); + GGML_ASSERT(cur && "tensor not found in ctx_data"); + auto it_off = tensor_offset.find(t->name); + GGML_ASSERT(it_off != tensor_offset.end() && "no offset for tensor"); + const size_t offset = it_off->second; + fin.seekg(offset, std::ios::beg); + if (!fin) { + throw std::runtime_error(string_format("%s: failed to seek for tensor %s\n", __func__, t->name)); + } + size_t num_bytes = ggml_nbytes(cur); + if (ggml_backend_buft_is_host(buft)) { + // for the CPU and Metal backend, we can read directly into the tensor + fin.read(reinterpret_cast(cur->data), num_bytes); + } else { + // read into a temporary buffer first, then copy to device memory + read_buf.resize(num_bytes); + fin.read(reinterpret_cast(read_buf.data()), num_bytes); + ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); + } + data_loaded += num_bytes; + if (progress_callback && total_data_size > 0) { + const float progress = (float)data_loaded / (float)total_data_size; + if (!progress_callback(progress, progress_callback_user_data)) { + throw std::runtime_error(string_format("%s: model loading cancelled by progress_callback\n", __func__)); + } + } } + LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str()); + } else { + LOG_DBG("%s: no_alloc is set, skipping tensor data loading (%zu tensors)\n", __func__, tensors_to_load.size()); } fin.close(); - - LOG_DBG("%s: loaded %zu tensors from %s\n", __func__, tensors_to_load.size(), fname.c_str()); } } @@ -2842,6 +2889,12 @@ struct clip_model_loader { img.set_size({sz, sz}, false, false); LOG_INF("%s: warmup with image size = %d x %d\n", __func__, sz, sz); } else { + // GEMMA4UA uses n_mel_bins as a raw-waveform frame size (640), not a mel-bin count, + // so the [1, 256] bound only applies to FFT-based models. + const bool fft_based = ctx_clip.model.proj_type != PROJECTOR_TYPE_GEMMA4UA; + if (hparams.n_mel_bins <= 0 || (fft_based && hparams.n_mel_bins > 256)) { + throw std::runtime_error(string_format("%s: invalid n_mel_bins (%d), must be in [1, 256]\n", __func__, hparams.n_mel_bins)); + } img.set_size({hparams.warmup_audio_size, hparams.n_mel_bins}, false, false); LOG_INF("%s: warmup with audio size = %d\n", __func__, hparams.warmup_audio_size); } @@ -3005,7 +3058,13 @@ struct clip_model_loader { } return; } - output = gguf_get_val_u32(ctx_gguf.get(), i); + const uint32_t val = gguf_get_val_u32(ctx_gguf.get(), i); + // sanity check + if (val > (uint32_t) INT32_MAX) { + throw std::runtime_error(string_format("%s: value %u for key '%s' exceeds INT32_MAX\n", + __func__, val, key.c_str())); + } + output = (int) val; } void get_f32(const std::string & key, float & output, bool required = true) const { @@ -3088,7 +3147,10 @@ struct clip_init_result clip_init(const char * fname, struct clip_context_params clip_ctx * ctx_audio = nullptr; try { - clip_model_loader loader(fname); + clip_model_loader loader(fname, + /* skip_tensors */ false, + ctx_params.progress_callback, + ctx_params.progress_callback_user_data); bool skip_audio = false; if (loader.has_vision) { @@ -4349,7 +4411,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32 // Stage 1b only uses block 0's permutations; future stages // will upload all blocks. - for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) { + for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) { const std::string prefix = "g4v_blk" + std::to_string(bid) + "_"; upload(prefix + "win_idx", make_win_idx(image_side, window_side)); upload(prefix + "qwin_idx", make_win_idx(new_side, query_side)); diff --git a/tools/mtmd/clip.h b/tools/mtmd/clip.h index f66f4bc3bb..967093a812 100644 --- a/tools/mtmd/clip.h +++ b/tools/mtmd/clip.h @@ -24,6 +24,9 @@ struct clip_image_size { return !(*this == other); } int area() const { + // avoid overflow when computing area + GGML_ASSERT(width >= 0 && width <= 46000); + GGML_ASSERT(height >= 0 && height <= 46000); return width * height; } }; @@ -51,6 +54,8 @@ struct clip_context_params { ggml_backend_sched_eval_callback cb_eval; void * cb_eval_user_data; bool no_alloc; + mtmd_progress_callback progress_callback; + void * progress_callback_user_data; }; struct clip_init_result { diff --git a/tools/mtmd/models/granite-speech.cpp b/tools/mtmd/models/granite-speech.cpp index 0bd4d75ac5..a158a59ce9 100644 --- a/tools/mtmd/models/granite-speech.cpp +++ b/tools/mtmd/models/granite-speech.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include + ggml_cgraph * clip_graph_granite_speech::build() { const int n_frames = img.nx(); const int context_size = hparams.audio_chunk_size; @@ -11,6 +13,10 @@ ggml_cgraph * clip_graph_granite_speech::build() { const int padded_len = num_blocks * context_size; const int remainder = n_frames % context_size; + // Calculate projector input dimension based on feature layers + const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1); + const bool use_feature_concat = !hparams.feature_layers.empty(); + ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size); ggml_set_name(attn_dists, "attn_dists"); ggml_set_input(attn_dists); @@ -31,6 +37,15 @@ ggml_cgraph * clip_graph_granite_speech::build() { cur = ggml_add(ctx0, cur, model.inp_proj_b); cb(cur, "inp_linear", -1); + // Capture layer 0 if requested (after input_linear) + ggml_tensor * concat_result = nullptr; + if (use_feature_concat) { + if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) { + concat_result = cur; + cb(concat_result, "feature_layer_0", -1); + } + } + for (int il = 0; il < n_layer; il++) { const auto & layer = model.layers[il]; auto * residual = cur; @@ -168,6 +183,18 @@ ggml_cgraph * clip_graph_granite_speech::build() { NORM_TYPE_NORMAL, eps, il); cb(cur, "layer_out", il); + // Capture intermediate layer (il + 1) if requested + if (use_feature_concat) { + if (hparams.is_feature_layer(il + 1)) { + if (concat_result == nullptr) { + concat_result = cur; + } else { + concat_result = ggml_concat(ctx0, concat_result, cur, 0); + } + cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il); + } + } + // CTC branch if (il + 1 == ctc_layer) { auto * mid = build_mm(model.ctc_out_w, cur); @@ -180,6 +207,13 @@ ggml_cgraph * clip_graph_granite_speech::build() { } } + // Append final output to concatenated features if using feature concatenation + if (use_feature_concat && concat_result != nullptr) { + concat_result = ggml_concat(ctx0, concat_result, cur, 0); + cb(concat_result, "concat_final", -1); + cur = concat_result; + } + cb(cur, "encoder_out", -1); // QFormer projector @@ -197,7 +231,7 @@ ggml_cgraph * clip_graph_granite_speech::build() { cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0); } - ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj); + ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj); ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query, model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b, diff --git a/tools/mtmd/models/granite4-vision.cpp b/tools/mtmd/models/granite4-vision.cpp index 9adb6f0fdb..1b252543c0 100644 --- a/tools/mtmd/models/granite4-vision.cpp +++ b/tools/mtmd/models/granite4-vision.cpp @@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() { } // --- Stage 1b/1c: WindowQFormer blocks --- - const int projector_count = hparams.vision_feature_layer.size(); + const int projector_count = hparams.feature_layers.size(); const float qformer_eps = 1e-12f; ggml_tensor * mmproj = nullptr; for (int bid = 0; bid < projector_count; ++bid) { const auto & blk = model.qf_proj_blocks[bid]; - int vlayer = hparams.vision_feature_layer[bid]; + int vlayer = hparams.feature_layers[bid]; GGML_ASSERT(vlayer >= 0 && vlayer < n_layer); ggml_tensor * h = layer_outs[vlayer]; diff --git a/tools/mtmd/models/internvl.cpp b/tools/mtmd/models/internvl.cpp index 9aded3b97c..65d7d5a6b7 100644 --- a/tools/mtmd/models/internvl.cpp +++ b/tools/mtmd/models/internvl.cpp @@ -8,7 +8,9 @@ ggml_cgraph * clip_graph_internvl::build() { ggml_tensor * inp = build_inp(); // add CLS token - inp = ggml_concat(ctx0, inp, model.class_embedding, 1); + ggml_tensor * cls_repeated = ggml_repeat_4d(ctx0, model.class_embedding, + model.class_embedding->ne[0], 1, n_batch, 1); + inp = ggml_concat(ctx0, inp, cls_repeated, 1); // The larger models use a different ViT, which uses RMS norm instead of layer norm // ref: https://github.com/ggml-org/llama.cpp/pull/13443#issuecomment-2869786188 @@ -24,14 +26,15 @@ ggml_cgraph * clip_graph_internvl::build() { nullptr); // remove CLS token - cur = ggml_view_2d(ctx0, cur, - n_embd, n_patches, - ggml_row_size(cur->type, n_embd), 0); + cur = ggml_view_3d(ctx0, cur, + n_embd, n_patches, n_batch, + cur->nb[1], cur->nb[2], 0); + cur = ggml_cont(ctx0, cur); // pixel shuffle { const int scale_factor = model.hparams.n_merge; - const int bsz = 1; // batch size, always 1 for now since we don't support batching + const int bsz = n_batch; const int height = n_patches_y; const int width = n_patches_x; GGML_ASSERT(scale_factor > 0); @@ -44,9 +47,10 @@ ggml_cgraph * clip_graph_internvl::build() { bsz); cur = ggml_permute(ctx0, cur, 0, 2, 1, 3); // flatten to 2D - cur = ggml_cont_2d(ctx0, cur, + cur = ggml_cont_3d(ctx0, cur, n_embd * scale_factor * scale_factor, - cur->ne[1] * cur->ne[2]); + cur->ne[1] * cur->ne[2], + cur->ne[3]); } // projector (always using GELU activation) diff --git a/tools/mtmd/models/llava.cpp b/tools/mtmd/models/llava.cpp index 5aa3d2f0fa..47efe68bd8 100644 --- a/tools/mtmd/models/llava.cpp +++ b/tools/mtmd/models/llava.cpp @@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() { // If we set explicit vision feature layers, only go up to the deepest one // NOTE: only used by granite-vision models for now - for (const auto & feature_layer : hparams.vision_feature_layer) { + for (const auto & feature_layer : hparams.feature_layers) { if (feature_layer > deepest_feature_layer) { deepest_feature_layer = feature_layer; } @@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() { // If this is an embedding feature layer, save the output. // NOTE: 0 index here refers to the input to the encoder. - if (hparams.is_vision_feature_layer(il)) { + if (hparams.is_feature_layer(il)) { embedding_stack.push_back(cur); } @@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() { // process vision feature layers (used by granite) { // final layer is a vision feature layer - if (hparams.is_vision_feature_layer(max_feature_layer)) { + if (hparams.is_feature_layer(max_feature_layer)) { embedding_stack.push_back(inpL); } diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 3d2fa4073f..5f1493fa60 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -80,6 +80,7 @@ struct clip_graph_minicpmv4_6 : clip_graph { struct clip_graph_internvl : clip_graph { clip_graph_internvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; + bool support_batch() const override { return true; } }; struct clip_graph_nemotron_v2_vl : clip_graph { diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index 13f211fd90..b72fd067a5 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -32,8 +32,8 @@ void mtmd_audio_cache::fill_hann_window(uint32_t length, bool periodic) { } } -void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, - int n_fft, +void mtmd_audio_cache::fill_mel_filterbank_matrix(int64_t n_mel, + int64_t n_fft, int sample_rate, float fmin, float fmax, @@ -86,11 +86,16 @@ void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, hz_pts[i] = mel_to_hz(mel_pts[i]); } - const int n_fft_bins = n_fft / 2 + 1; + const int64_t n_fft_bins = n_fft / 2 + 1; + + // Validate allocation size + if ((size_t)n_mel * (size_t)n_fft_bins > SIZE_MAX) { + GGML_ASSERT(false && "mel filterbank allocation too large"); + } // filterbank - std::vector out(n_mel * n_fft_bins, 0); - for (int m = 0; m < n_mel; ++m) { + std::vector out((size_t)n_mel * (size_t)n_fft_bins, 0); + for (int64_t m = 0; m < n_mel; ++m) { const double f_left = hz_pts[m]; const double f_center = hz_pts[m + 1]; const double f_right = hz_pts[m + 2]; @@ -266,8 +271,8 @@ static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) } struct filter_params { - int32_t n_mel; - int32_t n_fft_bins; + int64_t n_mel; + int64_t n_fft_bins; int32_t hann_window_size; int32_t hop_length; int32_t sample_rate; @@ -293,8 +298,8 @@ static void log_mel_spectrogram_worker_thread(int ith, std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); - int n_fft_bins = params.n_fft_bins; - int i = ith; + int64_t n_fft_bins = params.n_fft_bins; + int64_t i = ith; const auto & filters = cache.filters; @@ -302,17 +307,18 @@ static void log_mel_spectrogram_worker_thread(int ith, GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero - for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { - const int offset = i * frame_step; + for (; i < std::min((int64_t)(n_samples / frame_step + 1), out.n_len); i += n_threads) { + const int64_t offset = i * frame_step; // apply Hann window (~10% faster) - for (int j = 0; j < std::min(frame_size, n_samples - offset); j++) { + const int valid_len = std::min(frame_size, std::max(0, n_samples - (int)offset)); + for (int j = 0; j < valid_len; j++) { fft_in[j] = hann[j] * samples[offset + j]; } // fill the rest with zeros - if (n_samples - offset < frame_size) { - std::fill(fft_in.begin() + (n_samples - offset), fft_in.end(), 0.0); + if (valid_len < frame_size) { + std::fill(fft_in.begin() + valid_len, fft_in.end(), 0.0); } // FFT @@ -325,7 +331,7 @@ static void log_mel_spectrogram_worker_thread(int ith, } // mel spectrogram - for (int j = 0; j < out.n_mel; j++) { + for (int64_t j = 0; j < out.n_mel; j++) { double sum = 0.0; // unroll loop (suggested by GH user @lunixbochs) int k = 0; @@ -339,21 +345,21 @@ static void log_mel_spectrogram_worker_thread(int ith, } // handle n_fft remainder for (; k < n_fft_bins; k++) { - sum += fft_out[k] * filters.data[j * n_fft_bins + k]; + sum += fft_out[k] * filters.data[(size_t)j * n_fft_bins + k]; } sum = std::max(sum, (double)params.mel_floor); sum = params.use_natural_log ? log(sum) : log10(sum); - out.data[j * out.n_len + i] = sum; + out.data[(size_t)j * out.n_len + i] = sum; } } // Otherwise fft_out are all zero double sum = params.use_natural_log ? log(1e-10) : log10(1e-10); for (; i < out.n_len; i += n_threads) { - for (int j = 0; j < out.n_mel; j++) { - out.data[j * out.n_len + i] = sum; + for (int64_t j = 0; j < out.n_mel; j++) { + out.data[(size_t)j * out.n_len + i] = sum; } } } @@ -437,16 +443,21 @@ static bool log_mel_spectrogram( GGML_ASSERT(params.hop_length > 0); out.n_mel = params.n_mel; out.n_len = (n_samples - frame_size) / frame_step + 1; - // TODO: handle these checks better - if (out.n_mel > 0 && (unsigned long)out.n_len > SIZE_MAX / out.n_mel) { - LOG_ERR("%s: size overflow\n", __func__); + // Validate dimensions before allocation to prevent integer overflow + if (out.n_mel <= 0 || out.n_len <= 0) { + LOG_ERR("%s: invalid mel dimensions n_mel=%lld n_len=%lld\n", __func__, (long long)out.n_mel, (long long)out.n_len); + return false; + } + const size_t total_size = (size_t)out.n_mel * (size_t)out.n_len; + if (total_size > SIZE_MAX / sizeof(float)) { + LOG_ERR("%s: size overflow: n_mel=%lld n_len=%lld\n", __func__, (long long)out.n_mel, (long long)out.n_len); return false; } if (n_samples < frame_size) { LOG_ERR("%s: not enough samples after padding\n", __func__); return false; } - out.data.resize(out.n_mel * out.n_len); + out.data.resize(total_size); { std::vector workers(n_threads - 1); @@ -464,38 +475,39 @@ static bool log_mel_spectrogram( } } - const int effective_n_len = n_samples_in / frame_step; + const int64_t effective_n_len = n_samples_in / frame_step; if (params.norm_per_feature) { GGML_ASSERT(effective_n_len > 1); - for (int i = 0; i < out.n_mel; i++) { + for (int64_t i = 0; i < out.n_mel; i++) { double mean = 0; - for (int j = 0; j < effective_n_len; ++j) { - mean += out.data[i * out.n_len + j]; + for (int64_t j = 0; j < effective_n_len; ++j) { + mean += out.data[(size_t)i * out.n_len + j]; } mean /= effective_n_len; double var = 0.0; - for (int j = 0; j < effective_n_len; ++j) { - const double value = out.data[i * out.n_len + j] - mean; + for (int64_t j = 0; j < effective_n_len; ++j) { + const double value = out.data[(size_t)i * out.n_len + j] - mean; var += value * value; } var /= effective_n_len - 1; // unbiased const double mstd = std::sqrt(var + 1e-5); - for (int j = 0; j < effective_n_len; ++j) { - auto &value = out.data[i * out.n_len + j]; + for (int64_t j = 0; j < effective_n_len; ++j) { + auto &value = out.data[(size_t)i * out.n_len + j]; value = (value - mean) / mstd; } // pad the rest with zeros - for (int j = effective_n_len; j < out.n_len; ++j) { - out.data[i * out.n_len + j] = 0.0; + for (int64_t j = effective_n_len; j < out.n_len; ++j) { + out.data[(size_t)i * out.n_len + j] = 0.0; } } } else if (!params.no_padding) { // Whisper-style clamping and normalization (NOT used by Gemma4) double mmax = -1e20; - for (int i = 0; i < out.n_mel*out.n_len; i++) { + const size_t mel_size = (size_t)out.n_mel * (size_t)out.n_len; + for (size_t i = 0; i < mel_size; i++) { if (out.data[i] > mmax) { mmax = out.data[i]; } @@ -503,7 +515,7 @@ static bool log_mel_spectrogram( mmax -= 8.0; - for (int i = 0; i < out.n_mel*out.n_len; i++) { + for (size_t i = 0; i < mel_size; i++) { if (out.data[i] < mmax) { out.data[i] = mmax; } @@ -582,13 +594,13 @@ bool mtmd_audio_preprocessor_whisper::preprocess(const float * s // because the cgraph in clip.cpp only accepts 3000 frames each, we need to split the mel // we always expect the mel to have 3000 silent frames at the end if (DEBUG) { - printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); + printf("output: n_mel = %d, n_len = %d\n", (int) out_full.n_mel, (int) out_full.n_len); } const size_t frames_per_chunk = 3000; GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk); for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) { - int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off); - if ((size_t) n_len < frames_per_chunk) { + int64_t n_len = std::min((int64_t)frames_per_chunk, out_full.n_len - (int64_t)off); + if (n_len < (int64_t)frames_per_chunk) { break; // last incomplete chunk will always be a padded chunk, safe to ignore } @@ -596,10 +608,10 @@ bool mtmd_audio_preprocessor_whisper::preprocess(const float * s out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; out_chunk.n_len_org = out_full.n_mel; // unused - out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len); + out_chunk.data.reserve((size_t)out_chunk.n_mel * (size_t)out_chunk.n_len); - for (int i = 0; i < out_full.n_mel; i++) { - auto src = out_full.data.begin() + i * out_full.n_len + off; + for (int64_t i = 0; i < out_full.n_mel; i++) { + auto src = out_full.data.begin() + (size_t)i * out_full.n_len + off; out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk); } @@ -681,8 +693,8 @@ bool mtmd_audio_preprocessor_qwen3a::preprocess(const float * sa // The effective frame count: center-padded STFT gives ~n_samples/hop_length frames. // We take min(mel_full.n_len, n_samples/hop + 1) to avoid including excess frames. - const int n_eff = std::min(mel_full.n_len, - (int)(n_samples / hparams.audio_hop_len) + 1); + const int64_t n_eff = std::min(mel_full.n_len, + (int64_t)(n_samples / hparams.audio_hop_len) + 1); // Split into inference windows matching n_window_infer=800 from model config. // Each window is padded to the next multiple of chunk_size for the cgraph. @@ -690,18 +702,18 @@ bool mtmd_audio_preprocessor_qwen3a::preprocess(const float * sa const int chunk_size = 100; // conv sub-chunk size (n_window * 2, n_window=50) const int window_size = 800; // mel frames per forward pass (n_window_infer=800) - for (int off = 0; off < n_eff; off += window_size) { - const int win_eff = std::min(window_size, n_eff - off); - const int n_chunks = (win_eff + chunk_size - 1) / chunk_size; - const int n_padded = n_chunks * chunk_size; + for (int64_t off = 0; off < n_eff; off += window_size) { + const int64_t win_eff = std::min((int64_t)window_size, n_eff - off); + const int64_t n_chunks = (win_eff + chunk_size - 1) / chunk_size; + const int64_t n_padded = n_chunks * chunk_size; mtmd_audio_mel out; out.n_mel = mel_full.n_mel; out.n_len = n_padded; out.n_len_org = win_eff; - out.data.assign(out.n_mel * out.n_len, 0.0f); - for (int m = 0; m < out.n_mel; m++) { - const int copy_len = std::min(win_eff, mel_full.n_len - off); + out.data.assign((size_t)out.n_mel * (size_t)out.n_len, 0.0f); + for (int64_t m = 0; m < out.n_mel; m++) { + const int64_t copy_len = std::min((int64_t)win_eff, mel_full.n_len - off); if (copy_len > 0) { std::copy(mel_full.data.begin() + (size_t)m * mel_full.n_len + off, mel_full.data.begin() + (size_t)m * mel_full.n_len + off + copy_len, @@ -823,37 +835,38 @@ bool mtmd_audio_preprocessor_granite_speech::preprocess(const float * } double mmax = -1e20; - for (int i = 0; i < mel.n_mel * mel.n_len; i++) { + const size_t mel_size = (size_t)mel.n_mel * (size_t)mel.n_len; + for (size_t i = 0; i < mel_size; i++) { if (mel.data[i] > mmax) { mmax = mel.data[i]; } } mmax -= 8.0; - for (int i = 0; i < mel.n_mel * mel.n_len; i++) { + for (size_t i = 0; i < mel_size; i++) { if (mel.data[i] < mmax) { mel.data[i] = mmax; } mel.data[i] = (mel.data[i] + 4.0) / 4.0; } - int n_frames = mel.n_len; + int64_t n_frames = mel.n_len; if (n_frames % 2 == 1) { n_frames--; } - const int n_mel = mel.n_mel; - const int n_stacked = n_frames / 2; + const int64_t n_mel = mel.n_mel; + const int64_t n_stacked = n_frames / 2; mtmd_audio_mel stacked; stacked.n_mel = 2 * n_mel; stacked.n_len = n_stacked; - stacked.n_len_org = (int)n_samples; - stacked.data.resize(2 * n_mel * n_stacked); + stacked.n_len_org = (int64_t)n_samples; + stacked.data.resize((size_t)2 * (size_t)n_mel * (size_t)n_stacked); - for (int t = 0; t < n_stacked; t++) { - for (int m = 0; m < n_mel; m++) { - stacked.data[m * n_stacked + t] = mel.data[m * mel.n_len + 2 * t]; - stacked.data[(m + n_mel) * n_stacked + t] = mel.data[m * mel.n_len + 2 * t + 1]; + for (int64_t t = 0; t < n_stacked; t++) { + for (int64_t m = 0; m < n_mel; m++) { + stacked.data[(size_t)m * n_stacked + t] = mel.data[(size_t)m * mel.n_len + 2 * t]; + stacked.data[(size_t)(m + n_mel) * n_stacked + t] = mel.data[(size_t)m * mel.n_len + 2 * t + 1]; } } @@ -921,8 +934,8 @@ bool mtmd_audio_preprocessor_gemma4a::preprocess(const float * s const int hop = hparams.audio_hop_len; const int n_with_left = (int)chunk_len + pad_left; // PyTorch: unfold(size=frame_length+1, step=hop) on semicausal-padded waveform - const int pt_frames = (n_with_left - (hparams.audio_window_len + 1)) / hop + 1; - const int n_padded_needed = (pt_frames - 1) * hop + fft_size; + const int64_t pt_frames = (n_with_left - (hparams.audio_window_len + 1)) / hop + 1; + const int64_t n_padded_needed = (pt_frames - 1) * hop + fft_size; const int total_pad = std::max((int)(n_padded_needed - (int)chunk_len), pad_left); std::vector padded_samples(total_pad + chunk_len, 0.0f); std::copy(chunk_ptr, chunk_ptr + chunk_len, padded_samples.data() + pad_left); diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index 9656e3940f..ad96bd847c 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -10,16 +10,16 @@ #define MTMD_INTERNAL_HEADER struct mtmd_audio_mel { - int n_len; - int n_len_org; - int n_mel; + int64_t n_len; + int64_t n_len_org; + int64_t n_mel; std::vector data; }; struct mtmd_audio_mel_filters { - int32_t n_mel; - int32_t n_fft; + int64_t n_mel; + int64_t n_fft; std::vector data; }; @@ -39,8 +39,8 @@ struct mtmd_audio_cache { // Build mel filterbank matrix [n_mel ร— n_fft_bins] at runtime. // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. - void fill_mel_filterbank_matrix(int n_mel, - int n_fft, + void fill_mel_filterbank_matrix(int64_t n_mel, + int64_t n_fft, int sample_rate, // e.g. 16000 float fmin = 0.0f, // e.g. 0.0 float fmax = -1.0f, // e.g. sr/2; pass -1 for auto diff --git a/tools/mtmd/mtmd-cli.cpp b/tools/mtmd/mtmd-cli.cpp index a3cad7cd06..8704ea79d7 100644 --- a/tools/mtmd/mtmd-cli.cpp +++ b/tools/mtmd/mtmd-cli.cpp @@ -32,9 +32,9 @@ static volatile bool g_is_generating = false; static volatile bool g_is_interrupted = false; /** - * Please note that this is NOT a production-ready stuff. + * Please note that this is NOT a production-ready binary. * It is a playground for trying multimodal support in llama.cpp. - * For contributors: please keep this code simple and easy to understand. + * For contributors: please keep this code simple and easy to understand. Do not add unnecessary complexity. The goal is to have a simple CLI for testing multimodal support. */ static void show_additional_info(int /*argc*/, char ** argv) { @@ -65,6 +65,14 @@ static void sigint_handler(int signo) { } #endif +// this is only used by tests.sh to capture the response ; it's not meant to be used in production +static void inject_test_response_marker() { + const char * env = std::getenv("MTMD_TEST_RESPONSE_MARKER"); + if (env) { + LOG("%s\n", env); + } +} + struct mtmd_cli_context { mtmd::context_ptr ctx_vision; common_init_result_ptr llama_init; @@ -79,6 +87,8 @@ struct mtmd_cli_context { mtmd::bitmaps bitmaps; std::vector videos; + mtmd::batch_ptr mbatch; + // chat template common_chat_templates_ptr tmpls; std::vector chat_history; @@ -233,6 +243,8 @@ static std::string chat_add_and_format(mtmd_cli_context & ctx, common_chat_msg & } static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { + inject_test_response_marker(); + bool add_bos = ctx.chat_history.empty(); auto formatted_chat = chat_add_and_format(ctx, msg); LOG_DBG("formatted_chat.prompt: %s\n", formatted_chat.c_str()); @@ -259,20 +271,95 @@ static int eval_message(mtmd_cli_context & ctx, common_chat_msg & msg) { ctx.bitmaps.entries.clear(); ctx.videos.clear(); - llama_pos new_n_past; - if (mtmd_helper_eval_chunks(ctx.ctx_vision.get(), - ctx.lctx, // lctx - chunks.ptr.get(), // chunks - ctx.n_past, // n_past - 0, // seq_id - ctx.n_batch, // n_batch - true, // logits_last - &new_n_past)) { - LOG_ERR("Unable to eval prompt\n"); - return 1; - } + // batch encode all media chunks, then decode each + size_t n_chunks = mtmd_input_chunks_size(chunks.ptr.get()); + for (size_t i = 0; i < n_chunks; i++) { + auto chunk = mtmd_input_chunks_get(chunks.ptr.get(), i); + auto chunk_type = mtmd_input_chunk_get_type(chunk); - ctx.n_past = new_n_past; + if (chunk_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + // decode text chunk + llama_pos new_n_past = ctx.n_past; + res = mtmd_helper_eval_chunk_single(ctx.ctx_vision.get(), + ctx.lctx, + chunk, + ctx.n_past, + 0, // seq_id + ctx.n_batch, + i == n_chunks - 1, // logits_last + &new_n_past); + if (res != 0) { + LOG_ERR("Unable to eval text chunk %zu\n", i); + return 1; + } + ctx.n_past = new_n_past; + } else { + // media chunk: try to get embd from existing batch, or create a new batch + float * embd = nullptr; + if (ctx.mbatch) { + embd = mtmd_batch_get_output_embd(ctx.mbatch.get(), chunk); + + if (embd) { + LOG_DBG("found embd for media chunk %zu in existing batch\n", i); + } else { + LOG_DBG("media chunk %zu not found in existing batch, creating new batch\n", i); + } + } + + if (!embd) { + // create and encode a new batch with as many media chunks as possible + ctx.mbatch.reset(mtmd_batch_init(ctx.ctx_vision.get())); + res = mtmd_batch_add_chunk(ctx.mbatch.get(), chunk); + GGML_ASSERT(res == 0); // first chunk must always succeed + + int n_added = 1; + // add as many subsequent media chunks as possible + for (size_t j = i + 1; j < n_chunks; j++) { + auto next_chunk = mtmd_input_chunks_get(chunks.ptr.get(), j); + auto next_type = mtmd_input_chunk_get_type(next_chunk); + if (next_type == MTMD_INPUT_CHUNK_TYPE_TEXT) { + break; // text chunk splits the batch + } + res = mtmd_batch_add_chunk(ctx.mbatch.get(), next_chunk); + if (res != 0) { + break; // batch full or incompatible + } + n_added++; + } + + int64_t time_start = ggml_time_ms(); + LOG_INF("encoding mtmd batch, n_chunks = %d (done = %zu, total = %zu)\n", n_added, i, n_chunks); + res = mtmd_batch_encode(ctx.mbatch.get()); + if (res != 0) { + LOG_ERR("Failed to encode mtmd batch, res = %d\n", res); + return 1; + } + LOG_INF("mtmd batch encoding done in %d ms\n", (int)(ggml_time_ms() - time_start)); + + embd = mtmd_batch_get_output_embd(ctx.mbatch.get(), chunk); + } + + GGML_ASSERT(embd != nullptr); + + llama_pos new_n_past = ctx.n_past; + res = mtmd_helper_decode_image_chunk(ctx.ctx_vision.get(), + ctx.lctx, + chunk, + embd, + ctx.n_past, + 0, // seq_id + ctx.n_batch, + &new_n_past, + nullptr, // callback + nullptr // user_data + ); + if (res != 0) { + LOG_ERR("Unable to decode media chunk %zu\n", i); + return 1; + } + ctx.n_past = new_n_past; + } + } LOG("\n"); @@ -309,6 +396,9 @@ int main(int argc, char ** argv) { int n_predict = params.n_predict < 0 ? INT_MAX : params.n_predict; + console::init(params.simple_io, params.use_color); + atexit([]() { console::cleanup(); }); + // Ctrl+C handling { #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index b5c4089232..3c73db4431 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -582,13 +582,29 @@ mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_buf(mtmd_context * ctx, } mtmd_helper_bitmap_wrapper mtmd_helper_bitmap_init_from_file(mtmd_context * ctx, const char * fname, bool placeholder) { - std::vector buf; +#ifdef _WIN32 + int wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, NULL, 0); + if (!wlen) { + LOG_ERR("Unable to convert filename to UTF-16: %s\n", fname); + return {nullptr, nullptr}; + } + std::vector wfname(wlen); + wlen = MultiByteToWideChar(CP_UTF8, 0, fname, -1, wfname.data(), wlen); + if (!wlen) { + LOG_ERR("Unable to convert filename to UTF-16: %s\n", fname); + return {nullptr, nullptr}; + } + FILE * f = _wfopen(wfname.data(), L"rb"); +#else FILE * f = fopen(fname, "rb"); +#endif if (!f) { LOG_ERR("Unable to open file %s: %s\n", fname, strerror(errno)); return {nullptr, nullptr}; } + std::vector buf; + fseek(f, 0, SEEK_END); long file_size = ftell(f); fseek(f, 0, SEEK_SET); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index d063dc8a8e..e51de8fbd3 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -251,6 +251,8 @@ mtmd_context_params mtmd_context_params_default() { /* cb_eval */ nullptr, /* cb_eval_user_data */ nullptr, /* batch_max_tokens */ 1024, + /* progress_callback */ nullptr, + /* progress_callback_user_data */ nullptr, }; return params; } @@ -345,6 +347,8 @@ struct mtmd_context { /* cb_eval */ ctx_params.cb_eval, /* cb_eval_user_data */ ctx_params.cb_eval_user_data, /* no_alloc */ no_alloc, + /* progress_callback */ ctx_params.progress_callback, + /* progress_callback_user_data */ ctx_params.progress_callback_user_data, }; auto res = clip_init(mmproj_fname, ctx_clip_params); @@ -1296,9 +1300,12 @@ struct mtmd_tokenizer { for (auto & mel_spec : mel_spec_chunks) { const bool is_placeholder = mel_spec.data.empty(); + // Validate dimensions fit in clip_image_size (int) + GGML_ASSERT(mel_spec.n_len <= INT32_MAX && mel_spec.n_len >= 0); + GGML_ASSERT(mel_spec.n_mel <= INT32_MAX && mel_spec.n_mel >= 0); clip_image_f32 mel_f32; mel_f32.set_size( - {mel_spec.n_len, mel_spec.n_mel}, + {(int)mel_spec.n_len, (int)mel_spec.n_mel}, is_placeholder, /* is_audio */ true); mel_f32.cpy_buf(mel_spec.data); @@ -2131,9 +2138,12 @@ std::map mtmd_get_memory_usage(const char * mmproj_f mtmd::context_ptr ctx; auto saved_log_callback = g_logger_state.log_callback; auto saved_log_user_data = g_logger_state.log_callback_user_data; + + ctx_params.progress_callback = nullptr; + try { mtmd_log_set(stub_log_callback, nullptr); // suppress logging - ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params)); + ctx.reset(new mtmd_context(mmproj_fname, nullptr, ctx_params, true)); mtmd_log_set(saved_log_callback, saved_log_user_data); // restore log callback std::map total_mem; auto merge = [&](const struct clip_ctx * c) { diff --git a/tools/mtmd/mtmd.h b/tools/mtmd/mtmd.h index 2fd149e480..25d51ef58d 100644 --- a/tools/mtmd/mtmd.h +++ b/tools/mtmd/mtmd.h @@ -83,6 +83,8 @@ typedef struct mtmd_input_chunks mtmd_input_chunks; typedef struct mtmd_input_text mtmd_input_text; typedef struct mtmd_batch mtmd_batch; +typedef bool (*mtmd_progress_callback)(float progress, void * user_data); + struct mtmd_context_params { bool use_gpu; bool print_timings; @@ -104,6 +106,12 @@ struct mtmd_context_params { int32_t batch_max_tokens; // maximum number of output tokens in a batch // (note: this is not a hard-limit, the first image will always be added even if it exceeds this limit) // (default: 1024) + + // Called with a progress value between 0.0 and 1.0. Pass NULL to disable. + // If the provided progress_callback returns true, model loading continues. + // If it returns false, model loading is immediately aborted. + mtmd_progress_callback progress_callback; + void * progress_callback_user_data; }; MTMD_API const char * mtmd_default_marker(void); diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 5da48d61bf..6fe26478ab 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -13,6 +13,8 @@ mkdir -p $SCRIPT_DIR/output PROJ_ROOT="$SCRIPT_DIR/../.." cd $PROJ_ROOT +export MTMD_TEST_RESPONSE_MARKER="" + # Check if the first argument is "big", then run test with big models # This is useful if we're running the script on a larger machine, so we can test the big models RUN_BIG_TESTS=false @@ -28,6 +30,15 @@ if [ "${1:-}" = "huge" ]; then echo "Include BIG and HUGE models..." fi +USE_VIDEO=false +if [ "${1:-}" = "video" ]; then + USE_VIDEO=true + echo "Using video as input..." + # behavior of USE_VIDEO: + # do NOT check if the output contains "new york", only verify if the exit code is 0 + # when printing the result, print the OK/FAIL line then print the generated text +fi + # Check if the second argument is "flash", then enable flash attention # This is useful to test if flash attention off works correctly FLASH_ATTN="on" @@ -50,13 +61,20 @@ add_test_vision() { if [ $# -gt 0 ]; then extra_args=$(printf " %q" "$@") fi + if [ "$USE_VIDEO" = true ]; then + arr_file+=("test-3.mp4") + else + arr_file+=("test-1.jpeg") + fi arr_prefix+=("[vision]") arr_hf+=("$hf") arr_extra_args+=("$extra_args") - arr_file+=("test-1.jpeg") } add_test_audio() { + if [ "$USE_VIDEO" = true ]; then + return 0 + fi local hf=$1 shift local extra_args="" @@ -166,19 +184,35 @@ for i in "${!arr_hf[@]}"; do cmd+=" -p \"what is the publisher name of the newspaper?\"" fi - output=$(eval "$cmd" 2>&1 | tee /dev/tty) + exit_code=0 + output=$(eval "$cmd" 2>&1 | tee /dev/tty) || exit_code=$? echo "$output" > $SCRIPT_DIR/output/$bin-$(echo "$hf" | tr '/' '-').log - # either contains "new york" or both "men" and "walk" - if echo "$output" | grep -iq "new york" \ - || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk") - then - result="$prefix \033[32mOK\033[0m: $hf" + if [ "$USE_VIDEO" = true ]; then + # for video, only check exit code; do not grep for "new york" + if [ $exit_code -eq 0 ]; then + result="$prefix \033[32mOK\033[0m: $hf" + else + result="$prefix \033[31mFAIL\033[0m: $hf" + fi + # append generated text (after the response marker) + generated_text=$(echo "$output" | sed "1,/${MTMD_TEST_RESPONSE_MARKER}/d" | tail -10) + if [ -n "$generated_text" ]; then + result+="\n$generated_text" + fi + echo -e "$result" else - result="$prefix \033[31mFAIL\033[0m: $hf" + # either contains "new york" or both "men" and "walk" + if echo "$output" | grep -iq "new york" \ + || (echo "$output" | grep -iq "men" && echo "$output" | grep -iq "walk") + then + result="$prefix \033[32mOK\033[0m: $hf" + else + result="$prefix \033[31mFAIL\033[0m: $hf" + fi + echo -e "$result" fi - echo -e "$result" arr_res+=("$result") echo "" diff --git a/tools/mtmd/tests/test-deepseek-ocr.py b/tools/mtmd/tests/test-deepseek-ocr.py index f641045355..043a6988a4 100644 --- a/tools/mtmd/tests/test-deepseek-ocr.py +++ b/tools/mtmd/tests/test-deepseek-ocr.py @@ -9,6 +9,7 @@ its output, and holds them against the HF model's scores. import argparse import logging +import re import subprocess import sys import unicodedata @@ -28,6 +29,12 @@ class ModelSpec: mmproj_arg: str model_default: str mmproj_default: str + prompt: str = "Free OCR. " + n_predict: int = 512 + n_ctx: int | None = None + # Unlimited-OCR's "document parsing" prompt emits <|det|> grounding markup that + # the HF reference strips in result.md; drop it before scoring to match. + strip_grounding: bool = False @dataclass @@ -63,6 +70,20 @@ MODELS = { model_default="gguf_models/deepseek-ai/deepseek-ocr-2-bf16.gguf", mmproj_default="gguf_models/deepseek-ai/mmproj-deepseek-ocr-2-bf16.gguf", ), + "unlimited": ModelSpec( + key="unlimited", label="Unlimited-OCR", + model_arg="--llama-model-unlimited", mmproj_arg="--mmproj-unlimited", + model_default="gguf_models/baidu/unlimited-ocr-bf16.gguf", + mmproj_default="gguf_models/baidu/mmproj-unlimited-ocr-bf16.gguf", + # "Free OCR." immediately emits EOS on this checkpoint; the HF reference + # (demo/unlimited_ocr_scores.py) uses "document parsing.", which grounds. + prompt="document parsing.", + # Grounding emits ~3x the tokens of plain OCR, so it needs a larger budget + # and context to reach the article body the ground truth covers. + n_predict=4096, + n_ctx=16384, + strip_grounding=True, + ), } CASES = [ @@ -100,9 +121,26 @@ CASES = [ # 2 local 768 tiles + 1 global 1024 view = 545 image tokens. hf_cer=0.0236, hf_chrf=97.05, cer_tol=0.03, chrf_tol=3.0, ), + TestCase( + model_key="unlimited", label="single-view scan", + image="tools/mtmd/test-1.jpeg", + ground_truth="tools/mtmd/tests/test-1-ground-truth.txt", + # HF reference: Unlimited-OCR scoring (gundam, bf16) on this image/ground-truth. + # Decoder runs full MHA, not R-SWA; the band absorbs that gap + bf16 variance. + hf_cer=0.1869, hf_chrf=75.23, cer_tol=0.06, chrf_tol=6.0, + ), ] +GROUNDING_TAG_RE = re.compile(r"<\|(ref|det)\|>.*?<\|/\1\|>", re.DOTALL) + + +def strip_grounding(text: str) -> str: + """Drop <|ref|>..<|/ref|> / <|det|>..<|/det|> grounding markup, matching the + cleaned result.md the HF reference scores against.""" + return GROUNDING_TAG_RE.sub("", text) + + def arg_dest(flag: str) -> str: return flag.lstrip("-").replace("-", "_") @@ -147,19 +185,19 @@ def compute_chrf(expected: str, ocr_out: str) -> float: return CHRF().sentence_score(ocr_out, [expected]).score -def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str: +def run_mtmd_cli(spec: "ModelSpec", model_path, mmproj_path, image_path, bin_path) -> str: """Run mtmd-cli on the image and return its output.""" cmd = [ str(bin_path), "-m", str(model_path), "--mmproj", str(mmproj_path), "--image", str(image_path), - "-p", "Free OCR. ", + "-p", spec.prompt, "--chat-template", "deepseek-ocr", "--temp", "0", "--flash-attn", "off", # match the HF "eager" attention reference "--no-warmup", - "-n", "512", # cap loops on hard images (KV would otherwise fill) + "-n", str(spec.n_predict), # cap loops on hard images (KV would otherwise fill) # HF decodes with no_repeat_ngram_size; llama.cpp's analog is DRY. # Default DRY breakers include "\n", so they are cleared below. "--dry-multiplier", "0.8", @@ -168,6 +206,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str: "--dry-penalty-last-n", "-1", "--dry-sequence-breaker", "none", ] + if spec.n_ctx is not None: + cmd += ["-c", str(spec.n_ctx)] logger.debug(f" command: {' '.join(cmd)}") try: @@ -182,6 +222,8 @@ def run_mtmd_cli(model_path, mmproj_path, image_path, bin_path) -> str: raise RuntimeError(f"llama-mtmd-cli failed with code {result.returncode}") output = result.stdout.decode("utf-8", errors="replace").strip() + if spec.strip_grounding: + output = strip_grounding(output) if not output: raise RuntimeError("llama-mtmd-cli produced no output on stdout") logger.info(f" output: {len(output)} chars") @@ -211,7 +253,7 @@ def evaluate(case: "TestCase", expected: str, ocr_out: str) -> bool: logger.info("") logger.info("=" * 60) - logger.info("Free OCR evaluation:") + logger.info("OCR evaluation:") logger.info("=" * 60) logger.info(f" CER {cer:>7.4f} (HF {case.hf_cer:.4f}, <= {case.cer_max:>7.4f} -> {verdict(cer_pass)})") logger.info(f" chrF (0-100) {chrf:>7.2f} (HF {case.hf_chrf:.2f}, >= {case.chrf_min:>7.2f} -> {verdict(chrf_pass)})") @@ -287,9 +329,9 @@ def main() -> int: expected = read_expected_text(ground_truth) logger.info(f" Image: {case.image}") logger.info(f" Expected text: {len(expected)} chars") - logger.info(" Running llama.cpp 'Free OCR'") + logger.info(f" Running llama.cpp prompt {model_spec.prompt!r}") try: - ocr_out = run_mtmd_cli(model, mmproj, image, binary) + ocr_out = run_mtmd_cli(model_spec, model, mmproj, image, binary) except RuntimeError as e: logger.error(f" Error: {e}") results[title] = False diff --git a/tools/parser/debug-template-parser.cpp b/tools/parser/debug-template-parser.cpp index 9c591a1f11..50e8f1efb7 100644 --- a/tools/parser/debug-template-parser.cpp +++ b/tools/parser/debug-template-parser.cpp @@ -40,6 +40,7 @@ struct debug_options { bool enable_reasoning = true; bool debug_jinja = false; bool force_tool_call = false; + bool parallel_tool_calls = true; output_mode mode = output_mode::BOTH; input_message_type input_message = input_message_type::NONE; }; @@ -87,6 +88,7 @@ static void print_usage(const char * program_name) { LOG_ERR("\nOptions:\n"); LOG_ERR(" --no-tools Disable tool definitions\n"); LOG_ERR(" --force-tool-call Set tool calls to forced\n"); + LOG_ERR(" --parallel-tool-calls=0|1 Set parallel_tool_calls (default: 1)\n"); LOG_ERR(" --generation-prompt=0|1 Set add_generation_prompt (default: 1)\n"); LOG_ERR(" --enable-reasoning=0|1 Enable reasoning parsing (default: 1)\n"); LOG_ERR(" --output=MODE Output mode: analysis, template, both (default: both)\n"); @@ -121,6 +123,8 @@ static bool parse_options(int argc, char ** argv, debug_options & opts) { opts.debug_jinja = true; } else if (arg == "--no-tools") { opts.with_tools = false; + } else if (arg.rfind("--parallel-tool-calls=", 0) == 0) { + opts.parallel_tool_calls = parse_bool_option(arg.substr(22)); } else if (arg.rfind("--generation-prompt=", 0) == 0) { opts.generation_prompt = parse_bool_option(arg.substr(20)); } else if (arg.rfind("--enable-reasoning=", 0) == 0) { @@ -349,7 +353,7 @@ static autoparser::generation_params prepare_params(const debug_options & opts, params.tools = json(); params.tool_choice = COMMON_CHAT_TOOL_CHOICE_NONE; } - params.parallel_tool_calls = false; + params.parallel_tool_calls = opts.parallel_tool_calls; return params; } diff --git a/tools/server/README-dev.md b/tools/server/README-dev.md index 4c41031239..5959745e47 100644 --- a/tools/server/README-dev.md +++ b/tools/server/README-dev.md @@ -180,6 +180,17 @@ That requires `JSON.stringify` when formatted to message content: } ``` +### Router mode: how child <--> router communicates + +Upon spawning a new child process using `subprocess`, both child and router listen to the stdout/stderr (combined) + +For the direction from child to router: +- Generic messages are logs, it will be forwarded to router's stdout +- Special state update messages are prefixed by `cmd_child_to_router:state:`, followed by a JSON. See `server_models::handle_child_state` for more + +For the direction from router to child: +- When server sends `cmd_router_to_child:exit`, the child should exit gracefully --> if after `DEFAULT_STOP_TIMEOUT` and the child is still running, force-kill it + ### Model management API (router mode) Model management API was added via PR [#23976](https://github.com/ggml-org/llama.cpp/pull/23976) @@ -193,9 +204,9 @@ Instead of building everything from the ground up (like what most AI agents will The flow for downloading a new model: - POST request comes in --> `post_router_models` --> validation -- `server_models::download()` is called - - Sets up a new thread `inst.th` and runs the download inside -- If a stop request comes in, set `stop_download` to `true` +- A new `llama-server` subprocess will be spawned with special `SERVER_CHILD_MODE_DOWNLOAD` +- Child process runs the download and report status back to router via stdin/out +- If a stop request comes in, the router asks the child process to stop (same mechanism as running a model in child process) - Otherwise, upon completion, we call `load_models()` to refresh the list of models ### Notable Related PRs diff --git a/tools/server/README.md b/tools/server/README.md index 88a507e2c5..e88bc5f28a 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -175,13 +175,12 @@ For the full list of features, please refer to [server's changelog](https://gith | `-np, --parallel N` | number of server slots (default: -1, -1 = auto)
(env: LLAMA_ARG_N_PARALLEL) | | `-cb, --cont-batching, -nocb, --no-cont-batching` | whether to enable continuous batching (a.k.a dynamic batching) (default: enabled)
(env: LLAMA_ARG_CONT_BATCHING) | | `-mm, --mmproj FILE` | path to a multimodal projector file. see tools/mtmd/README.md
note: if -hf is used, this argument can be omitted
(env: LLAMA_ARG_MMPROJ) | -| `-tk, --talker-model FILE` | path to the qwen3-omni talker gguf, enables the /v1/audio/speech endpoint
(env: LLAMA_ARG_TALKER_MODEL) | -| `-c2w, --code2wav-model FILE` | path to the qwen3-omni code2wav gguf, the talker code detokenizer
(env: LLAMA_ARG_CODE2WAV_MODEL) | | `-mmu, --mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md
(env: LLAMA_ARG_MMPROJ_URL) | | `--mmproj-auto, --no-mmproj, --no-mmproj-auto` | whether to use multimodal projector file (if available), useful when using -hf (default: enabled)
(env: LLAMA_ARG_MMPROJ_AUTO) | | `--mmproj-offload, --no-mmproj-offload` | whether to enable GPU offloading for multimodal projector (default: enabled)
(env: LLAMA_ARG_MMPROJ_OFFLOAD) | | `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MIN_TOKENS) | | `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)
(env: LLAMA_ARG_IMAGE_MAX_TOKENS) | +| `--mtmd-batch-max-tokens N` | maximum number of image tokens per batch when encoding images (default: 1024)
(env: LLAMA_ARG_MTMD_BATCH_MAX_TOKENS) | | `-a, --alias STRING` | set model name aliases, comma-separated (to be used by API)
(env: LLAMA_ARG_ALIAS) | | `--tags STRING` | set model tags, comma-separated (informational, not used for routing)
(env: LLAMA_ARG_TAGS) | | `--embd-normalize N` | normalisation for embeddings (default: 2) (-1=none, 0=max absolute int16, 1=taxicab, 2=euclidean, >2=p-norm) | @@ -190,23 +189,21 @@ For the full list of features, please refer to [server's changelog](https://gith | `--reuse-port` | allow multiple sockets to bind to the same port (default: disabled)
(env: LLAMA_ARG_REUSE_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | | `--api-prefix PREFIX` | prefix path the server serves from, without the trailing slash (default: )
(env: LLAMA_ARG_API_PREFIX) | -| `--webui-config JSON` | [DEPRECATED: use --ui-config] JSON that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG) | -| `--ui-config JSON` | JSON that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG) | -| `--webui-config-file PATH` | [DEPRECATED: use --ui-config-file] JSON file that provides default WebUI settings (overrides WebUI defaults)
(env: LLAMA_ARG_WEBUI_CONFIG_FILE) | -| `--ui-config-file PATH` | JSON file that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG_FILE) | -| `--webui-mcp-proxy, --no-webui-mcp-proxy` | [DEPRECATED: use --ui-mcp-proxy/--no-ui-mcp-proxy] experimental: whether to enable MCP CORS proxy
(env: LLAMA_ARG_WEBUI_MCP_PROXY) | -| `--ui-mcp-proxy, --no-ui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_UI_MCP_PROXY) | +| `--ui-config, --webui-config JSON` | JSON that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG) | +| `--ui-config-file, --webui-config-file PATH` | JSON file that provides default UI settings (overrides UI defaults)
(env: LLAMA_ARG_UI_CONFIG_FILE) | +| `--ui-mcp-proxy, --webui-mcp-proxy, --no-ui-mcp-proxy, --no-webui-mcp-proxy` | experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_UI_MCP_PROXY) | | `--tools TOOL1,TOOL2,...` | experimental: whether to enable built-in tools for AI agents - do not enable in untrusted environments (default: no tools)
specify "all" to enable all tools
available tools: read_file, file_glob_search, grep_search, exec_shell_command, write_file, edit_file, apply_diff, get_datetime
(env: LLAMA_ARG_TOOLS) | -| `--webui, --no-webui` | [DEPRECATED: use --ui/--no-ui] whether to enable the Web UI
(env: LLAMA_ARG_WEBUI) | -| `--ui, --no-ui` | whether to enable the Web UI (default: enabled)
(env: LLAMA_ARG_UI) | +| `-ag, --agent, -no-ag, --no-agent` | whether to enable CORS proxy and all built-in tools - do not enable in untrusted environments (default: disabled)
(env: LLAMA_ARG_AGENT) | +| `--ui, --webui, --no-ui, --no-webui` | whether to enable the Web UI (default: enabled)
(env: LLAMA_ARG_UI) | | `--embedding, --embeddings` | restrict to only support embedding use case; use only with dedicated embedding models (default: disabled)
(env: LLAMA_ARG_EMBEDDINGS) | | `--rerank, --reranking` | enable reranking endpoint on server (default: disabled)
(env: LLAMA_ARG_RERANKING) | | `--api-key KEY` | API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)
(env: LLAMA_API_KEY) | -| `--api-key-file FNAME` | path to file containing API keys (default: none)
(env: LLAMA_ARG_API_KEY_FILE) | +| `--api-key-file FNAME` | path to file containing API keys, one per line; lines starting with a hash are treated as comments (default: none)
(env: LLAMA_ARG_API_KEY_FILE) | | `--ssl-key-file FNAME` | path to file a PEM-encoded SSL private key
(env: LLAMA_ARG_SSL_KEY_FILE) | | `--ssl-cert-file FNAME` | path to file a PEM-encoded SSL certificate
(env: LLAMA_ARG_SSL_CERT_FILE) | | `--chat-template-kwargs STRING` | sets additional params for the json template parser, must be a valid json object string, e.g. '{"key1":"value1","key2":"value2"}'
(env: LLAMA_ARG_CHAT_TEMPLATE_KWARGS) | | `-to, --timeout N` | server read/write timeout in seconds (default: 3600)
(env: LLAMA_ARG_TIMEOUT) | +| `--sse-ping-interval N` | server SSE ping interval in seconds (-1 = disabled, default: 30)
(env: LLAMA_ARG_SSE_PING_INTERVAL) | | `--threads-http N` | number of threads used to process HTTP requests (default: -1)
(env: LLAMA_ARG_THREADS_HTTP) | | `--cache-prompt, --no-cache-prompt` | whether to enable prompt caching (default: enabled)
(env: LLAMA_ARG_CACHE_PROMPT) | | `--cache-reuse N` | min chunk size to attempt reusing from the cache via KV shifting, requires prompt caching to be enabled (default: 0)
[(card)](https://ggml.ai/f0.png)
(env: LLAMA_ARG_CACHE_REUSE) | @@ -231,6 +228,7 @@ For the full list of features, please refer to [server's changelog](https://gith | `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled) | | `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) | | `--sleep-idle-seconds SECONDS` | number of seconds of idleness after which the server will sleep (default: -1; -1 = disabled) | +| `--log-prompts-dir PATH` | Log prompts to directory (only used for debugging, default: disabled) | | `--spec-draft-hf, -hfd, -hfrd, --hf-repo-draft /[:quant]` | Same as --hf-repo, but for the draft model (default: unused)
(env: LLAMA_ARG_SPEC_DRAFT_HF_REPO) | | `--spec-draft-threads, -td, --threads-draft N` | number of threads to use during generation (default: same as --threads) | | `--spec-draft-threads-batch, -tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) | @@ -1232,8 +1230,6 @@ print(completion.choices[0].text) Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with OpenAI API spec is being made, in our experience it suffices to support many apps. Only models with a [supported chat template](https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template) can be used optimally with this endpoint. By default, the ChatML template will be used. -If model supports multimodal, you can input the media file via `image_url` content part. We support both base64 and remote URL as input. See OAI documentation for more. - *Options:* See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). llama.cpp `/completion`-specific features such as `mirostat` are also supported. @@ -1252,9 +1248,18 @@ The `response_format` parameter supports both plain JSON output (e.g. `{"type": `parallel_tool_calls` : Whether to enable parallel/multiple tool calls (only supported on some models, verification is based on jinja template). -For multimodal input: -- Content type `image_url` and `input_audio` are the same as OAI schema -- Content type `input_video` is an extension from OAI schema. For now, it only accepts base64 input +For multimodal input (typed content, `messages[i].content[j]`): +- If `type == "image_url"`: + - `image_url.url` can be a remote URL, base64 (raw or URI-encoded via `data:image/...;base64`) or path to local file + - Accepts formats supported by `stb_image` (jpeg, png, tga, bmp, gif, ...) +- If `type == "input_audio"`: + - Either `input_audio.data` or `input_audio.url` can be specified, can be a remote URL, raw base64 or path to local file + - Accepts formats supported by `miniaudio` (mp3, wav, flac) + - `input_audio.format` will be ignored, the file format will be determined automatically +- If `type == "input_video"`: + - Either `input_video.data` or `input_video.url` can be specified, can be a remote URL, raw base64 or path to local file + - Accepts formats supported by `ffmpeg` +- Note: for local file, make sure to set `--media-path`. File path must be prefixed by `file://` *Examples:* @@ -1861,9 +1866,37 @@ Example events: { "model": "...", - "event": "download_finished", + "event": "model_status", "data": { - "status": "loading" + "status": "loading", + "progress": { + "stages": ["text_model", "spec_model", "mmproj_model"], + "current": "text_model", + "value": 0.5 + } + } +} +// note for "loading" status: +// - subsequent events will follow the same order of "stages" list +// - mmap is may report incorrect progress on some platforms; if you need exact progress, use --no-mmap + +{ + "model": "...", + "event": "model_status", + "data": { + "status": "loaded", + "info": { + // note: only include info on first load + // waking up from sleep doesn't have this + } + } +} + +{ + "model": "...", + "event": "model_status", + "data": { + "status": "sleeping" } } diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index 75729e62dd..ac291d359a 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -12,6 +12,7 @@ #include #include #include +#include json format_error_response(const std::string & message, const enum error_type type) { std::string type_str; @@ -517,6 +518,14 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const { return max_idx; // all tokens are equal } +common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const { + std::map skips; + for (const auto & it : map_idx_to_media) { + skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get()); + } + return delims.split(tokens, skips); +} + bool server_tokens::validate(const struct llama_context * ctx) const { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -816,12 +825,21 @@ json oaicompat_completion_params_parse(const json & body) { return llama_params; } -// media_path always end with '/', see arg.cpp +// url can be +// - http(s):// for remote files +// - file:// for local files (only allowed if media_path is set) +// - data: for base64 encoded data with uri scheme (e.g. data:image/png;base64,...) +// - raw base64 encoded data static void handle_media( std::vector & out_files, - json & media_obj, - const std::string & media_path) { - std::string url = json_value(media_obj, "url", std::string()); + const std::string & url, + const std::string & media_path, + bool accept_base64_uri) { + if (!media_path.empty()) { + // should already be enforced by arg.cpp, but checking just in case + GGML_ASSERT(media_path.back() == DIRECTORY_SEPARATOR); + } + if (string_starts_with(url, "http")) { // download remote image // TODO @ngxson : maybe make these params configurable @@ -857,20 +875,28 @@ static void handle_media( data.assign((std::istreambuf_iterator(file)), std::istreambuf_iterator()); out_files.push_back(data); - } else { + } else if (accept_base64_uri && string_starts_with(url, "data:")) { // try to decode base64 image std::vector parts = string_split(url, /*separator*/ ','); if (parts.size() != 2) { - throw std::runtime_error("Invalid url value"); + throw std::runtime_error("Invalid uri-encoded base64 value"); } else if (!string_starts_with(parts[0], "data:image/")) { - throw std::runtime_error("Invalid url format: " + parts[0]); + throw std::runtime_error("Invalid uri format: " + parts[0]); } else if (!string_ends_with(parts[0], "base64")) { - throw std::runtime_error("url must be base64 encoded"); + throw std::runtime_error("uri must be base64 encoded"); } else { auto base64_data = parts[1]; auto decoded_data = base64_decode(base64_data); out_files.push_back(decoded_data); } + + } else { + // try as raw base64 string + auto decoded_data = base64_decode(url); + if (decoded_data.empty()) { + throw std::runtime_error("Invalid base64 value"); + } + out_files.push_back(decoded_data); } } @@ -956,14 +982,15 @@ json oaicompat_chat_params_parse( } for (auto & p : content) { - std::string type = json_value(p, "type", std::string()); + std::string type = json_value(p, "type", std::string()); if (type == "image_url") { if (!opt.allow_image) { throw std::runtime_error("image input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } json image_url = json_value(p, "image_url", json::object()); - handle_media(out_files, image_url, opt.media_path); + std::string url = json_value(image_url, "url", std::string()); + handle_media(out_files, url, opt.media_path, true); p["type"] = "media_marker"; p["text"] = get_media_marker(); @@ -974,17 +1001,11 @@ json oaicompat_chat_params_parse( throw std::runtime_error("audio input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } - json input_audio = json_value(p, "input_audio", json::object()); - std::string data = json_value(input_audio, "data", std::string()); - std::string format = json_value(input_audio, "format", std::string()); - // while we also support flac, we don't allow it here so we matches the OAI spec - if (format != "wav" && format != "mp3") { - throw std::invalid_argument("input_audio.format must be either 'wav' or 'mp3'"); - } - auto decoded_data = base64_decode(data); // expected to be base64 encoded - out_files.push_back(decoded_data); - - // TODO: add audio_url support by reusing handle_media() + // note: don't need to validate "format", it's redundant + json input_audio = json_value(p, "input_audio", json::object()); + std::string url = json_value(input_audio, "data", + json_value(input_audio, "url", std::string())); + handle_media(out_files, url, opt.media_path, false); p["type"] = "media_marker"; p["text"] = get_media_marker(); @@ -995,10 +1016,10 @@ json oaicompat_chat_params_parse( throw std::runtime_error("video input is not supported - hint: if this is unexpected, you may need to provide the mmproj"); } - json input_video = json_value(p, "input_video", json::object()); - std::string data = json_value(input_video, "data", std::string()); - auto decoded_data = base64_decode(data); // expected to be base64 encoded - out_files.push_back(decoded_data); + json input_video = json_value(p, "input_video", json::object()); + std::string url = json_value(input_video, "data", + json_value(input_video, "url", std::string())); + handle_media(out_files, url, opt.media_path, false); p["type"] = "media_marker"; p["text"] = get_media_marker(); @@ -1091,15 +1112,7 @@ json oaicompat_chat_params_parse( llama_params["chat_parser"] = chat_params.parser; } - llama_params["message_spans"] = json::array(); - - for (const auto & span : chat_params.message_spans) { - llama_params["message_spans"].push_back({ - { "role", span.role }, - { "pos", span.pos }, - { "len", span.len }, - }); - } + llama_params["message_delimiters"] = chat_params.message_delimiters.to_json(); // Reasoning budget: pass parameters through to sampling layer { @@ -1238,7 +1251,7 @@ json format_response_rerank( // other utils // -std::vector get_token_probabilities(llama_context * ctx, int idx) { +std::vector get_token_probabilities(llama_context * ctx, int idx, size_t n_top) { std::vector cur; const auto * logits = llama_get_logits_ith(ctx, idx); @@ -1257,21 +1270,34 @@ std::vector get_token_probabilities(llama_context * ctx, int i } } - // sort tokens by logits - std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) { - return a.logit > b.logit; - }); + // sort tokens by logits (partial: only the leading `n_top` need ordering) + if (n_top > cur.size()) { + n_top = cur.size(); + } + if (n_top > 0) { + std::partial_sort(cur.begin(), cur.begin() + n_top, cur.end(), + [](const llama_token_data & a, const llama_token_data & b) { + return a.logit > b.logit; + }); + } // apply softmax - float max_l = cur[0].logit; + float max_l = -std::numeric_limits::infinity(); + if (n_top > 0) { + max_l = cur[0].logit; // partial_sort guarantees the absolute maximum is at index 0 + } else { + for (const auto & t : cur) { + max_l = std::max(max_l, t.logit); + } + } float cum_sum = 0.0f; - for (size_t i = 0; i < cur.size(); ++i) { - float p = expf(cur[i].logit - max_l); - cur[i].p = p; + for (auto & t : cur) { + float p = expf(t.logit - max_l); + t.p = p; cum_sum += p; } - for (size_t i = 0; i < cur.size(); ++i) { - cur[i].p /= cum_sum; + for (auto & t : cur) { + t.p /= cum_sum; } return cur; diff --git a/tools/server/server-common.h b/tools/server/server-common.h index f286b3d156..c0eaec6b02 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -218,6 +218,9 @@ public: size_t get_common_prefix(const server_tokens & b) const; + // split the tokens into message spans, skipping over media chunks + common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const; + // make sure all text tokens are within the vocab range bool validate(const struct llama_context * ctx) const; @@ -326,7 +329,7 @@ json format_response_rerank( // other utils // -std::vector get_token_probabilities(llama_context * ctx, int idx); +std::vector get_token_probabilities(llama_context * ctx, int idx, size_t n_top); std::string safe_json_to_str(const json & data); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index aebca306a8..39b7eb218e 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -63,9 +63,99 @@ enum slot_state { SLOT_STATE_GENERATING, }; -enum server_state { - SERVER_STATE_LOADING_MODEL, // Server is starting up, model not fully loaded yet - SERVER_STATE_READY, // Server is ready and model is loaded +struct server_slot; // forward declaration + +struct server_batch { + llama_batch batch; + bool batch_rendered = false; + + struct token { + int32_t id_slot; + llama_token token; + llama_pos pos; + bool output; + }; + std::vector tokens; + int32_t n_tokens_alloc = 0; + + // track if given slot can be batched with slots already in the batch + server_slot * slot_batched = nullptr; + + float alora_scale = -1.0f; + size_t alora_disabled_id = 0; + + server_batch() { + batch.token = nullptr; // sentinel: uninitialized batch + } + + ~server_batch() { + if (batch.token != nullptr) { + llama_batch_free(batch); + } + } + + void init(int32_t n_tokens_alloc) { + this->n_tokens_alloc = n_tokens_alloc; + batch = llama_batch_init(n_tokens_alloc, 0, 1); + tokens.reserve(n_tokens_alloc); + } + + bool add(int32_t id_slot, llama_token token, llama_pos pos, bool output) { + GGML_ASSERT(batch.token != nullptr); + if ((int32_t)tokens.size() >= n_tokens_alloc) { + return false; + } + // LOG_INF("adding token to batch: slot=%d, token=%d, pos=%d, output=%d\n", id_slot, token, pos, output); + tokens.push_back({ id_slot, token, pos, output }); + return true; + } + + void clear() { + tokens.clear(); + common_batch_clear(batch); + slot_batched = nullptr; + alora_scale = -1.0f; + alora_disabled_id = 0; + batch_rendered = false; + } + + int32_t size() const { + return (int32_t)tokens.size(); + } + + void set_output(int32_t idx, bool output) { + GGML_ASSERT(idx >= 0 && idx < (int32_t)tokens.size()); + tokens[idx].output = output; + } + + void render() { + GGML_ASSERT(batch.token != nullptr); + common_batch_clear(batch); + for (int32_t i = 0; i < size(); i++) { + const auto & t = tokens[i]; + common_batch_add(batch, t.token, t.pos, { t.id_slot }, t.output); + } + batch_rendered = true; + } + + llama_batch get_view(int32_t off, int32_t n_tokens) const { + GGML_ASSERT(batch.token != nullptr); + GGML_ASSERT(batch_rendered); + GGML_ASSERT(off >= 0 && off < size()); + GGML_ASSERT(n_tokens > 0 && off + n_tokens <= size()); + + llama_batch view = { + n_tokens, + batch.token + off, + nullptr, + batch.pos + off, + batch.n_seq_id + off, + batch.seq_id + off, + batch.logits + off, + }; + + return view; + } }; struct server_slot { @@ -190,6 +280,7 @@ struct server_slot { // stats size_t n_sent_text = 0; // number of sent text character + // TODO @ngxson : move all metrics to a sub-struct for clarity int64_t t_start_process_prompt; int64_t t_start_generation; int64_t t_print_last = 0; @@ -353,12 +444,14 @@ struct server_slot { return n_draft_max; } - void update_batch(llama_batch & batch) { + // add sampled token of this slot to the batch, optionally add the speculative draft tokens if any + void handle_last_sampled_token(server_batch & batch) { + bool add_ok = true; if (spec_draft.empty()) { // no speculative decoding - i_batch = batch.n_tokens; + i_batch = batch.size(); - common_batch_add(batch, sampled, prompt.tokens.pos_next(), { this->id }, true); + add_ok &= batch.add(id, sampled, prompt.tokens.pos_next(), true); SLT_DBG(*this, "slot decode token, id=%d, n_ctx = %d, n_tokens = %d, truncated = %d\n", sampled, n_ctx, prompt.n_tokens(), truncated); @@ -368,19 +461,21 @@ struct server_slot { GGML_ASSERT(spec_i_batch.empty()); - spec_i_batch.push_back(batch.n_tokens); + spec_i_batch.push_back(batch.size()); for (size_t i = 0; i < spec_draft.size(); i++) { - spec_i_batch.push_back(batch.n_tokens + i + 1); + spec_i_batch.push_back(batch.size() + i + 1); } auto pos0 = prompt.tokens.pos_next(); - common_batch_add(batch, sampled, pos0++, { this->id }, true); + add_ok &= batch.add(id, sampled, pos0++, true); for (auto token : spec_draft) { - common_batch_add(batch, token, pos0++, { this->id }, true); + add_ok &= batch.add(this->id, token, pos0++, true); } } + GGML_ASSERT(add_ok && "batch must be large enough to hold the sampled and draft tokens"); + prompt.tokens.push_back(sampled); prompt.tokens.insert(spec_draft); } @@ -773,6 +868,8 @@ public: // note: chat_params must not be refreshed upon existing sleeping state server_chat_params chat_params; + server_state_callback_t callback_state = [](server_state, json) -> void {}; + server_context_impl() { mtmd_helper_log_set(common_log_default_callback, nullptr); } @@ -796,7 +893,7 @@ private: llama_context * ctx_tgt = nullptr; - llama_batch batch {}; + server_batch batch; llama_model_ptr model_dft; llama_context_ptr ctx_dft; @@ -825,8 +922,7 @@ private: server_metrics metrics; - json json_ui_settings = json::object(); // Primary: new name - json json_webui_settings = json::object(); // Deprecated: use json_ui_settings instead (kept for compat) + json json_ui_settings = json::object(); // Necessary similarity of prompt for slot selection float slot_prompt_similarity = 0.0f; @@ -837,6 +933,8 @@ private: bool sleeping = false; + int64_t t_last_load_progress_ms = 0; + void destroy() { spec.reset(); ctx_dft.reset(); @@ -849,8 +947,6 @@ private: mtmd_free(mctx); mctx = nullptr; - - llama_batch_free(batch); } void handle_sleeping_state(bool new_state) { @@ -867,18 +963,77 @@ private: sleeping = new_state; } + struct load_progress_data { + server_context_impl * ctx; + std::string stage; + std::vector stages; + int64_t t_last_load_progress_ms = 0; + load_progress_data(server_context_impl * ctx, const std::string & stage) : ctx(ctx), stage(stage) {} + }; + static bool load_progress_callback(float progress, void * user_data) { + auto * d = static_cast(user_data); + GGML_ASSERT(d); + // always emit the first and final sample; throttle the rest to one per 200ms + { + auto & t_last = d->t_last_load_progress_ms; + const int64_t t_now = ggml_time_ms(); + const bool first = t_last == 0; + const bool done = progress >= 1.0f; + const bool throttled = !first && !done && (t_now - t_last) < 200; + if (throttled) { + return true; + } + t_last = t_now; + } + if (d->ctx->callback_state) { + d->ctx->callback_state(SERVER_STATE_LOADING, { + {"stages", d->stages}, + {"current", d->stage}, + {"value", progress}, + }); + } + return true; + } + // load the model and initialize llama_context // this may also be called to resume from sleeping state bool load_model(common_params & params) { - bool is_resume = sleeping; + load_progress_data load_progress_text (this, "text_model"); + load_progress_data load_progress_mmproj(this, "mmproj_model"); + load_progress_data load_progress_spec (this, "spec_model"); - SRV_INF("loading model '%s'\n", params.model.path.c_str()); + const bool is_resume = sleeping; params_base = params; params_base.n_outputs_max = server_n_outputs_max(params_base); + const bool has_mmproj = !params.mmproj.path.empty(); + const bool has_draft = params.speculative.has_dft(); + const bool spec_mtp = std::find(params_base.speculative.types.begin(), + params_base.speculative.types.end(), + COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); + const bool has_spec = has_draft || spec_mtp; + + if (callback_state) { + std::vector stages = {"text_model"}; + if (has_spec) { + stages.push_back("spec_model"); + } + if (has_mmproj) { + stages.push_back("mmproj_model"); + } + load_progress_text.stages = stages; + load_progress_mmproj.stages = stages; + load_progress_spec.stages = stages; + + // trigger 0% progress + load_progress_callback(0.0f, &load_progress_text); + } + + + SRV_INF("loading model '%s'\n", params.model.path.c_str()); + std::string & mmproj_path = params_base.mmproj.path; - bool has_mmproj = !mmproj_path.empty(); mtmd_context_params mparams = mtmd_context_params_default(); if (has_mmproj) { mparams.use_gpu = params_base.mmproj_use_gpu; @@ -890,17 +1045,22 @@ private: mparams.image_max_tokens = params_base.image_max_tokens; mparams.batch_max_tokens = params_base.mtmd_batch_max_tokens; mparams.media_marker = get_media_marker(); + // progress callback + mparams.progress_callback = load_progress_callback; + mparams.progress_callback_user_data = &load_progress_mmproj; } // optionally get the memory usage of mmproj if (has_mmproj && params_base.fit_params) { + int64_t t_start = ggml_time_us(); auto mmproj_mem = mtmd_get_memory_usage(mmproj_path.c_str(), mparams); + int64_t t_elapsed = ggml_time_us() - t_start; if (!mmproj_mem.empty()) { size_t total = 0; for (auto & [dev, size] : mmproj_mem) { total += size; } - SRV_INF("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB\n", total / (1024.0 * 1024.0)); + SRV_INF("[mtmd] estimated worst-case memory usage of mmproj is %.2f MiB (took %.2f ms)\n", total / (1024.0 * 1024.0), t_elapsed / 1000.0); GGML_ASSERT(!params_base.fit_params_target.empty()); for (auto & [dev, size] : mmproj_mem) { for (size_t i = 0; i < ggml_backend_dev_count(); i++) { @@ -920,12 +1080,7 @@ private: // optionally reserve VRAM for the draft / MTP context before fitting the target model if (params_base.fit_params) { - const bool spec_mtp = std::find(params_base.speculative.types.begin(), - params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); - const bool has_draft = params_base.speculative.has_dft(); - - if (has_draft || spec_mtp) { + if (has_spec) { common_params params_dft = params_base; bool measure_model_bytes = true; @@ -995,6 +1150,12 @@ private: } } + // attach a progress callback + { + params_base.load_progress_callback = load_progress_callback; + params_base.load_progress_callback_user_data = &load_progress_text; + } + llama_init = common_init_from_params(params_base); model_tgt = llama_init->model(); @@ -1011,7 +1172,7 @@ private: add_bos_token = llama_vocab_get_add_bos(vocab); - if (params_base.speculative.has_dft()) { + if (has_draft) { // TODO speculative: move to common/speculative.cpp? const auto & params_spec = params_base.speculative.draft; @@ -1034,6 +1195,10 @@ private: auto mparams_dft = common_model_params_to_llama(params_dft); + // progress callback + mparams_dft.progress_callback = load_progress_callback; + mparams_dft.progress_callback_user_data = &load_progress_spec; + model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); if (model_dft == nullptr) { SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); @@ -1042,10 +1207,6 @@ private: auto cparams = common_context_params_to_llama(params_dft); - const bool spec_mtp = std::find(params_base.speculative.types.begin(), - params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end(); - if (spec_mtp) { cparams.ctx_type = LLAMA_CONTEXT_TYPE_MTP; } @@ -1056,11 +1217,17 @@ private: cparams.ctx_other = ctx_tgt; ctx_dft.reset(llama_init_from_model(model_dft.get(), cparams)); + if (ctx_dft == nullptr) { + SRV_ERR("%s", "failed to create draft context\n"); + return false; + } params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); - } else if (std::find(params_base.speculative.types.begin(), params_base.speculative.types.end(), - COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params_base.speculative.types.end()) { + } else if (spec_mtp) { + // no new model load, so we simply report 0.0 and 1.0 progress + load_progress_callback(0.0f, &load_progress_spec); + SRV_INF("creating MTP draft context against the target model '%s'\n", params_base.model.path.c_str()); @@ -1080,9 +1247,15 @@ private: params_base.speculative.draft.ctx_tgt = ctx_tgt; params_base.speculative.draft.ctx_dft = ctx_dft.get(); + + load_progress_callback(1.0f, &load_progress_spec); } if (has_mmproj) { + if (callback_state) { + callback_state(SERVER_STATE_LOADING, {{"stage", "mmproj_model"}}); + } + if (!is_resume) { mtmd_helper_log_set(common_log_default_callback, nullptr); } @@ -1218,7 +1391,7 @@ private: // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) { const int32_t n_batch = llama_n_batch(ctx_tgt); - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); + batch.init(std::max(n_batch, params_base.n_parallel)); } if (params_base.cache_ram_mib != 0) { @@ -1245,8 +1418,8 @@ private: if (!params_base.model_alias.empty()) { // backward compat: use first alias as model name model_name = *params_base.model_alias.begin(); - } else if (!params_base.model.name.empty()) { - model_name = params_base.model.name; + } else if (!params_base.model.get_name().empty()) { + model_name = params_base.model.get_name(); } else { // fallback: derive model name from file name auto model_path = std::filesystem::path(params_base.model.path); @@ -1263,6 +1436,10 @@ private: return init(); } + if (callback_state) { + callback_state(SERVER_STATE_READY, {}); + } + return true; } @@ -1302,16 +1479,12 @@ private: } } - // populate UI settings (from either new ui_config_json or deprecated webui_config_json) { - const std::string & cfg = !params_base.ui_config_json.empty() - ? params_base.ui_config_json - : params_base.webui_config_json; + const std::string & cfg = params_base.ui_config_json; if (!cfg.empty()) { try { json json_settings = json::parse(cfg); json_ui_settings = json_settings; - json_webui_settings = json_settings; // deprecated: keep in sync } catch (const std::exception & e) { SRV_ERR("%s: failed to parse UI config: %s\n", __func__, e.what()); return false; @@ -1343,6 +1516,9 @@ private: const bool enable_thinking = params_base.enable_reasoning != 0 && template_supports_thinking; SRV_INF("%s: chat template, thinking = %d\n", __func__, enable_thinking); + // IMPORTANT: chat_params is reused across sleeping / resuming states, + // never store llama_context/llama_model pointers in chat_params, + // as they may be invalidated after sleeping chat_params = { /* use_jinja */ params_base.use_jinja, /* prefill_assistant */ params_base.prefill_assistant, @@ -1395,11 +1571,23 @@ private: bool update_cache = false; + // if a specific slot is requested, use it (still goes through cache update logic below) + if (task.id_slot != -1) { + ret = get_slot_by_id(task.id_slot); + if (ret) { + SLT_INF(*ret, "selected slot by id (%d)\n", task.id_slot); + } + } + // find the slot that has at least n% prompt similarity - if (ret == nullptr && slot_prompt_similarity != 0.0f) { + if (slot_prompt_similarity != 0.0f) { float sim_best = 0; for (server_slot & slot : slots) { + if (task.id_slot != -1 && slot.id != task.id_slot) { + continue; + } + // skip the slot if it is not available if (slot.is_processing()) { continue; @@ -1426,8 +1614,10 @@ private: if (ret != nullptr) { const float f_keep = (sim_best*task.tokens.size()) / ret->prompt.tokens.size(); - SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", - sim_best, slot_prompt_similarity, f_keep); + if (task.id_slot == -1) { + SLT_INF(*ret, "selected slot by LCP similarity, sim_best = %.3f (> %.3f thold), f_keep = %.3f\n", + sim_best, slot_prompt_similarity, f_keep); + } // if we are about to lose a large portion of the existing context - save it in the prompt cache if (f_keep < 0.5f) { @@ -1815,8 +2005,7 @@ private: }); } } else { - // TODO: optimize this with min-p optimization - std::vector cur = get_token_probabilities(ctx_tgt, idx); + std::vector cur = get_token_probabilities(ctx_tgt, idx, n_probs_request); const size_t max_probs = cur.size(); const size_t n_probs = std::min(max_probs, n_probs_request); @@ -2158,6 +2347,8 @@ private: cur.update_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); cur.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + // stash the draft's speculative state with the checkpoint + common_speculative_get_state(spec.get(), slot.id, cur.data_spec); SLT_INF(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", @@ -2180,10 +2371,9 @@ private: } } - const int id_slot = task.id_slot; const int id_task = task.id; - server_slot * slot = id_slot != -1 ? get_slot_by_id(id_slot) : get_available_slot(task); + server_slot * slot = get_available_slot(task); // // slot scheduling logic @@ -2491,7 +2681,83 @@ private: } } + void iterate(std::vector & slots, std::function callback) { + for (auto & slot : slots) { + try { + callback(slot); + } catch (const std::exception & e) { + SLT_ERR(slot, "got exception: %s\n", e.what()); + send_error(slot, std::string("got exception: ") + e.what(), ERROR_TYPE_SERVER); + slot.release(); + } + } + } + + void iterate(std::vector & slots, std::function callback) { + for (auto & slot : slots) { + try { + callback(*slot); + } catch (const std::exception & e) { + SLT_ERR(*slot, "got exception: %s\n", e.what()); + send_error(*slot, std::string("got exception: ") + e.what(), ERROR_TYPE_SERVER); + slot->release(); + } + } + } + + void abort_all_slots(const std::string & reason) { + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, reason, ERROR_TYPE_SERVER); + slot.release(); + } + } + } + + // @ngxson : for debugging only + int64_t t_pre_decode = 0; + int64_t t_decode = 0; + int64_t t_post_decode = 0; + int64_t t_sampl = 0; + int64_t n_pre_decode = 0; + int64_t n_decode = 0; + int64_t n_post_decode = 0; + int64_t n_sampl = 0; +// #define DEBUG_TIMINGS +#ifdef DEBUG_TIMINGS + struct scoped_timer { + int64_t & t; + int64_t & n; + int64_t t_start; + scoped_timer(int64_t & t_, int64_t & n_) : t(t_), n(n_) { + t_start = ggml_time_us(); + } + ~scoped_timer() { + t += ggml_time_us() - t_start; + n++; + } + }; +#else + struct scoped_timer { + scoped_timer(int64_t &, int64_t &) {} + ~scoped_timer() {} + }; +#endif + void update_slots() { +#ifdef DEBUG_TIMINGS + static int64_t t_prev = 0; + int64_t t_start = ggml_time_us(); + if (t_start - t_prev > 5 * 1000 * 1000) { // every 5 seconds + t_prev = t_start; + SRV_INF("n_pre_decode = %" PRId64 "\n", n_pre_decode); + SRV_INF("avg t_pre_decode = %f ms\n", (double) t_pre_decode / n_pre_decode / 1000.0); + SRV_INF("avg t_decode = %f ms\n", (double) t_decode / n_decode / 1000.0); + SRV_INF("avg t_post_decode = %f ms\n", (double) t_post_decode / n_post_decode / 1000.0); + SRV_INF("avg t_sampl = %f ms\n", (double) t_sampl / n_sampl / 1000.0); + } +#endif + // check if all slots are idle { bool all_idle = true; @@ -2505,29 +2771,80 @@ private: if (all_idle) { SRV_INF("%s", "all slots are idle\n"); + return; // skip further processing - return; + } else { + SRV_DBG("%s", "posting NEXT_RESPONSE\n"); + + server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); + task.id = queue_tasks.get_new_id(); + queue_tasks.post(std::move(task)); } } - { - SRV_DBG("%s", "posting NEXT_RESPONSE\n"); - - server_task task(SERVER_TASK_TYPE_NEXT_RESPONSE); - task.id = queue_tasks.get_new_id(); - queue_tasks.post(std::move(task)); + try { + scoped_timer t(t_pre_decode, n_pre_decode); + pre_decode(); + batch.render(); + } catch (const std::exception & e) { + SRV_ERR("pre_decode() failed: %s\n", e.what()); + abort_all_slots("pre_decode() failed: " + std::string(e.what())); } + llama_batch batch_view; + int32_t off_next = 0; + int32_t n_batch = llama_n_batch(ctx_tgt); + for (int32_t off = 0; off < batch.size(); off = off_next) { + const int32_t n_tokens = std::min(n_batch, batch.size() - off); + try { + scoped_timer t(t_decode, n_decode); + // TODO @ngxson : maybe handle n_batch == 1 here instead of inside decode() + + batch_view = batch.get_view(off, n_tokens); + bool ok = decode(n_batch, off, batch_view); +#ifdef DEBUG_TIMINGS + llama_synchronize(ctx_tgt); +#endif + + if (ok) { + // move the head of the batch forward with the number of tokens we just processed + off_next = off + n_tokens; + + // on successful decode, restore the original batch size + n_batch = llama_n_batch(ctx_tgt); + } else { + // try again with the updated n_batch + continue; + } + } catch (const std::exception & e) { + SRV_ERR("decode() failed: %s\n", e.what()); + abort_all_slots("decode() failed: " + std::string(e.what())); + break; // stop any further processing + } + + try { + scoped_timer t(t_post_decode, n_post_decode); + post_decode(n_tokens, off, batch_view); + } catch (const std::exception & e) { + SRV_ERR("post_decode() failed: %s\n", e.what()); + abort_all_slots("post_decode() failed: " + std::string(e.what())); + break; // stop any further processing + } + + } + } + + void pre_decode() { // apply context-shift if needed // TODO: simplify and improve - for (server_slot & slot : slots) { + iterate(slots, [&](server_slot & slot) { if (slot.state == SLOT_STATE_GENERATING && slot.prompt.n_tokens() + 1 >= slot.n_ctx) { if (!params_base.ctx_shift) { // this check is redundant (for good) // we should never get here, because generation should already stopped in process_token() send_error(slot, "context shift is disabled", ERROR_TYPE_SERVER); slot.release(); - continue; + return; } if (mctx) { @@ -2539,7 +2856,7 @@ private: if (slot.task->is_parent() || slot.task->is_child()) { send_error(slot, "context shift cannot be used for shared prompt", ERROR_TYPE_SERVER); slot.release(); - continue; + return; } // Shift context @@ -2552,7 +2869,10 @@ private: n_keep = std::min(slot.n_ctx - 4, n_keep); const int n_left = slot.prompt.n_tokens() - n_keep; - const int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); + int n_discard = slot.task->params.n_discard ? slot.task->params.n_discard : (n_left / 2); + + // ref: https://github.com/ggml-org/llama.cpp/pull/24786 + n_discard = std::clamp(n_discard, 0, std::max(0, n_left - 1)); SLT_WRN(slot, "slot context shift, n_keep = %d, n_left = %d, n_discard = %d\n", n_keep, n_left, n_discard); @@ -2582,28 +2902,28 @@ private: slot.truncated = true; } - } + }); // start populating the batch for this iteration - common_batch_clear(batch); + batch.clear(); // track if given slot can be batched with slots already in the batch - server_slot * slot_batched = nullptr; + auto & slot_batched = batch.slot_batched; std::vector generating; std::vector drafting; // determine which slots are generating and drafting - for (auto & slot : slots) { + iterate(slots, [&](server_slot & slot) { if (slot.state != SLOT_STATE_GENERATING) { - continue; + return; } // check if we can batch this slot with the previous one if (!slot_batched) { slot_batched = &slot; } else if (!slot_batched->can_batch_with(slot)) { - continue; + return; } generating.push_back(&slot); @@ -2651,7 +2971,7 @@ private: } } } - } + }); // generate the actual drafts (if any) { @@ -2659,9 +2979,7 @@ private: } // make checkpoints if needed - for (auto * slot_ptr : drafting) { - auto & slot = *slot_ptr; - + iterate(drafting, [&](server_slot & slot) { auto & draft = slot.spec_draft; auto & ckpt = slot.spec_ckpt; @@ -2704,38 +3022,42 @@ private: ckpt.update_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); } } - } + }); // update the batch with the sampled/drafted tokens - for (auto * slot_ptr : generating) { - auto & slot = *slot_ptr; - - slot.update_batch(batch); - } + iterate(generating, [&](server_slot & slot) { + slot.handle_last_sampled_token(batch); + }); // process in chunks of params.n_batch int32_t n_batch = llama_n_batch(ctx_tgt); int32_t n_ubatch = llama_n_ubatch(ctx_tgt); - float alora_scale = -1.0f; - size_t alora_disabled_id = 0; + auto & alora_scale = batch.alora_scale; + auto & alora_disabled_id = batch.alora_disabled_id; // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { - for (auto & slot : slots) { + if (params_base.cont_batching || batch.size() == 0) { + bool add_ok = true; // false means the batch is full, skip remaining slots + + iterate(slots, [&](server_slot & slot) { + if (!add_ok || batch.size() >= n_batch) { + return; // batch is full, skip remaining slots + } + if (!slot.is_processing()) { - continue; + return; } // check if we can batch this slot with the previous one if (slot_batched && !slot_batched->can_batch_with(slot)) { - continue; + return; } // check if this is a child slot if (slot.state == SLOT_STATE_WAIT_OTHER) { SLT_DBG(slot, "%s", "waiting for parent slot to complete\n"); - continue; + return; } // this slot still has a prompt to be processed @@ -2743,7 +3065,7 @@ private: const auto & input_tokens = slot.task->tokens; // used to determine the number of tokens added to the batch for the current slot - const auto n_tokens_prev = batch.n_tokens; + const auto n_tokens_prev = batch.size(); // TODO: maybe move branch to outside of this loop in the future if (slot.state == SLOT_STATE_STARTED) { @@ -2779,14 +3101,14 @@ private: send_final_response(slot); slot.release(); - continue; + return; } // TODO: support memory-less logits computation if (slot.task->need_logits() && !llama_get_memory(ctx_tgt)) { send_error(slot, "the current context does not logits computation. skipping", ERROR_TYPE_SERVER); slot.release(); - continue; + return; } if (!slot.can_split()) { @@ -2798,7 +3120,7 @@ private: slot.task->n_tokens(), n_ubatch), ERROR_TYPE_SERVER); slot.release(); - continue; + return; } if (slot.task->n_tokens() > slot.n_ctx) { @@ -2809,7 +3131,7 @@ private: slot.task->n_tokens(), slot.n_ctx), ERROR_TYPE_EXCEED_CONTEXT_SIZE); slot.release(); - continue; + return; } } else { if (slot.task->n_tokens() >= slot.n_ctx) { @@ -2819,7 +3141,7 @@ private: slot.task->n_tokens(), slot.n_ctx), ERROR_TYPE_EXCEED_CONTEXT_SIZE); slot.release(); - continue; + return; } if (slot.task->params.cache_prompt) { @@ -2982,6 +3304,8 @@ private: // restore the context checkpoint it->load_tgt(ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); it->load_dft(ctx_dft.get(), slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + // restore the draft's speculative state + common_speculative_set_state(spec.get(), slot.id, it->data_spec); pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); @@ -3037,8 +3361,8 @@ private: if (!slot.can_split()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.task->n_tokens() > n_batch) { - continue; + if (batch.size() + slot.task->n_tokens() > n_batch) { + return; } } @@ -3118,11 +3442,11 @@ private: has_mtmd = true; } - const int32_t n_before_user = slot.task->params.n_before_user; - const bool n_before_user_known = n_before_user > 0; + const auto & spans = slot.task->params.message_spans; + const auto last_user_pos = spans.last_user_message_pos(); // add prompt tokens for processing in the current batch - while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.n_tokens < n_batch) { + while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) { // get next token to process llama_token cur_tok = input_tokens[slot.prompt.n_tokens()]; if (cur_tok == LLAMA_TOKEN_NULL) { @@ -3140,19 +3464,16 @@ private: // embedding requires all tokens in the batch to be output; // MTP also wants logits at every prompt position so the // streaming hook can mirror t_h_nextn into ctx_dft. - common_batch_add(batch, + add_ok &= batch.add(slot.id, cur_tok, slot.prompt.tokens.pos_next(), - { slot.id }, slot.need_embd()); slot.prompt.tokens.push_back(cur_tok); slot.n_prompt_tokens_processed++; - // stop the prompt batch exactly before the latest user input, so a checkpoint - // can be created after the previous messages - if (n_before_user_known && - slot.prompt.n_tokens() == n_before_user) { + // stop the prompt batch exactly before a user message + if (spans.is_user_start(slot.prompt.n_tokens())) { break; } @@ -3179,26 +3500,32 @@ private: } // the number of tokens added to the batch for the current slot - const auto n_tokens_cur = batch.n_tokens - n_tokens_prev; + const auto n_tokens_cur = batch.size() - n_tokens_prev; + + const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; + const bool is_user_start = spans.is_user_start(n_tokens_start); + const bool is_last_user_message = n_tokens_start == last_user_pos; + // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT(batch.size() > 0); // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + batch.set_output(batch.size() - 1, true); slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.i_batch = batch.size() - 1; slot.init_sampler(); } else { - // skip ordinary mid-prompt checkpoints - if (!n_before_user_known && !near_prompt_end) { + // skip ordinary mid-prompt checkpoints, unless the batch starts a user + // message or we are near the end of the prompt + if (!is_user_start && !near_prompt_end) { do_checkpoint = false; } } @@ -3206,29 +3533,6 @@ private: const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); - // checkpoints are created before the current batch is decoded, so - // their token position is the batch start rather than the prompt end - const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; - - { - const bool is_on_user = - n_before_user_known && - n_tokens_start == n_before_user; - - const bool is_after_user = - n_before_user_known && - n_tokens_start > n_before_user; - - const bool is_allowed = - !n_before_user_known || - is_on_user || - (is_after_user && near_prompt_end); - - if (do_checkpoint && !is_allowed) { - do_checkpoint = false; - } - } - // nothing to checkpoint yet // TODO: is this check needed? if (do_checkpoint && pos_min < 0) { @@ -3238,8 +3542,8 @@ private: // do not checkpoint after mtmd chunks do_checkpoint = do_checkpoint && !has_mtmd; - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step); + // no need to create checkpoints that are too close together, unless it's the last user message + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step); SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max); // note: we create the checkpoint before calling llama_decode(), so the current batch is not @@ -3252,20 +3556,20 @@ private: if (!slot_batched) { slot_batched = &slot; } - - if (batch.n_tokens >= n_batch) { - break; - } - } + }); } + } - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + // returns true = success ; false = retry with smaller batch size + // throw std::runtime_error on fatal error + bool decode(int32_t & n_batch, int32_t off, llama_batch & batch_view) { + SRV_DBG("n_batch (effective) = %d, off = %d\n", n_batch, off); - auto accept_special_token = [&](server_slot & slot, llama_token token) { - return params_base.special || - slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); - }; + auto & slot_batched = batch.slot_batched; + auto & alora_scale = batch.alora_scale; + auto & alora_disabled_id = batch.alora_disabled_id; + // TODO @ngxson : alora handling is too messy, need to refactor it to be more clear and maintainable if (slot_batched) { // apply lora, only need to do it once per batch common_set_adapter_lora(ctx_tgt, slot_batched->lora); @@ -3280,340 +3584,348 @@ private: llama_set_embeddings(ctx_tgt, slot_batched->need_embd()); } - if (batch.n_tokens == 0) { + if (batch.size() == 0) { SRV_WRN("%s", "no tokens to decode\n"); if (++n_empty_consecutive > 3) { GGML_ABORT("fatal error - please provide logs and repro in %s\n", "https://github.com/ggml-org/llama.cpp/pull/20277"); } + + return true; // nothing to decode } else { n_empty_consecutive = 0; } - int32_t i_next = 0; + const int ret = llama_decode(ctx_tgt, batch_view); - // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i = i_next) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); + metrics.on_decoded(slots); - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; + if (ret != 0) { + { + std::string err; - const int ret = llama_decode(ctx_tgt, batch_view); - - metrics.on_decoded(slots); - - if (ret != 0) { - { - std::string err; - - if (n_batch == 1 && ret == 1) { - // TODO: try to terminate only the largest active slot/sequence and continue with the rest - // need to remove the tokens from the current batch too - err = "Context size has been exceeded."; - } - - if (ret == -1) { - err = "Invalid input batch."; - } - - if (ret < -1) { - // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() - err = "Compute error."; - } - - // TODO: handle ret == 2 (abort) when we start aborting - - if (!err.empty()) { - SRV_ERR("%s i = %d, n_batch = %d, ret = %d\n", err.c_str(), i, n_batch, ret); - - for (auto & slot : slots) { - if (slot.is_processing()) { - send_error(slot, err); - slot.release(); - - // note: it's complicated to keep track of how much of the current batch has been - // processed before the error occurred, so we simply clear the entire context - slot.prompt_clear(false); - } - } - - break; - } + if (n_batch == 1 && ret == 1) { + // TODO: try to terminate only the largest active slot/sequence and continue with the rest + // need to remove the tokens from the current batch too + err = "Context size has been exceeded."; } - // retry with half the batch size to try to find a free slot in the KV cache - if (!try_clear_idle_slots()) { - n_batch /= 2; + if (ret == -1) { + err = "Invalid input batch."; } - SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, i = %d, n_batch = %d, ret = %d\n", i, n_batch, ret); + if (ret < -1) { + // TODO: update slot state based on llama_memory_seq_pos_min() and llama_memory_seq_pos_max() + err = "Compute error."; + } - continue; // continue loop of n_batch - } + // TODO: handle ret == 2 (abort) when we start aborting - // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] - // for now, always re-evaluate for simplicity - // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 - if (!common_speculative_process(spec.get(), batch_view)) { - SRV_ERR("%s", "failed to process speculative batch\n"); + if (!err.empty()) { + SRV_ERR("%s off = %d, n_batch = %d, ret = %d\n", err.c_str(), off, n_batch, ret); - // TODO: handle error - break; - } + for (auto & slot : slots) { + if (slot.is_processing()) { + send_error(slot, err); + slot.release(); - // move the head of the batch forward with the number of tokens we just processed - i_next = i + n_tokens; - - // on successful decode, restore the original batch size - n_batch = llama_n_batch(ctx_tgt); - - // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too - for (auto & slot : slots) { - if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) { - std::vector children; - for (auto & other : slots) { - if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) { - children.push_back(&other); + // note: it's complicated to keep track of how much of the current batch has been + // processed before the error occurred, so we simply clear the entire context + slot.prompt_clear(false); } } - // all children slots should already launched by launch_slots_with_parent_task() - // copy state to the child slots - for (auto & child : children) { - SLT_INF(slot, " - copying state to child %d\n", child->id); - - GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER); - - slot.copy_state_to(*child); - child->state = SLOT_STATE_DONE_PROMPT; - } + // stop, do not retry with smaller batch size + throw std::runtime_error(err); } } - for (auto & slot : slots) { - // optionally send prompt processing progress - if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->params.stream && slot.task->params.return_progress) { - send_partial_response(slot, {}, true); + // retry with half the batch size to try to find a free slot in the KV cache + if (!try_clear_idle_slots()) { + n_batch /= 2; + } + + SRV_WRN("failed to find free space in the KV cache, retrying with smaller batch size, off = %d, n_batch = %d, ret = %d\n", off, n_batch, ret); + + return false; // retry with the updated n_batch + } + + // TODO: avoid restoring the draft context and re-evaluating the drafted tokens when not needed [TAG_SPEC_AVOID_DRAFT_REEVAL] + // for now, always re-evaluate for simplicity + // ref: https://github.com/ggml-org/llama.cpp/pull/22728#issuecomment-4400925384 + if (!common_speculative_process(spec.get(), batch_view)) { + SRV_ERR("%s", "failed to process speculative batch\n"); + + // TODO: handle error + throw std::runtime_error("failed to process speculative batch"); + } + + // handle `n_cmpl > 1` tasks - when the main prompt is processed, activate all child tasks too + for (auto & slot : slots) { + if (slot.state == SLOT_STATE_DONE_PROMPT && slot.task->is_parent()) { + std::vector children; + for (auto & other : slots) { + if (other.state == SLOT_STATE_WAIT_OTHER && slot.task->id == other.task->id_parent) { + children.push_back(&other); } } - if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) { - continue; // continue loop of slots + // all children slots should already launched by launch_slots_with_parent_task() + // copy state to the child slots + for (auto & child : children) { + SLT_INF(slot, " - copying state to child %d\n", child->id); + + GGML_ASSERT(child->state == SLOT_STATE_WAIT_OTHER); + + slot.copy_state_to(*child); + child->state = SLOT_STATE_DONE_PROMPT; + } + } + } + + return true; + } + + void post_decode(int32_t n_batch_tokens, int32_t off, llama_batch & batch_view) { + // for checking if a given batch index is inside batch_view + auto is_inside_view = [&](int32_t idx) { + return idx >= off && idx < off + n_batch_tokens; + }; + + // TODO @ngxson : it's tricky to make sub-batch compatible with common_sampler_sample_and_accept_n, + // so for now we will throw an error in this case: https://github.com/ggml-org/llama.cpp/issues/24840 + iterate(slots, [&](server_slot & slot) { + for (auto & i : slot.spec_i_batch) { + if (!is_inside_view(i)) { + throw std::runtime_error(string_format("speculative batch index %d is not inside the current sub-batch [%d, %d)", i, off, off + n_batch_tokens)); + } + } + }); + + auto accept_special_token = [&](server_slot & slot, llama_token token) { + return params_base.special || + slot.task->params.sampling.preserved_tokens.find(token) != slot.task->params.sampling.preserved_tokens.end(); + }; + + iterate(slots, [&](server_slot & slot) { + // optionally send prompt processing progress + if (slot.state == SLOT_STATE_PROCESSING_PROMPT || slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->params.stream && slot.task->params.return_progress) { + send_partial_response(slot, {}, true); + } + } + + if (!is_inside_view(slot.i_batch)) { + // the required token not in this sub-batch, skip + return; + } + + if (slot.state == SLOT_STATE_DONE_PROMPT) { + if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { + // prompt evaluated for embedding + send_embedding(slot, batch_view); + slot.release(); + slot.i_batch = -1; + return; } - if (slot.state == SLOT_STATE_DONE_PROMPT) { - if (slot.task->type == SERVER_TASK_TYPE_EMBEDDING) { - // prompt evaluated for embedding - send_embedding(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - if (slot.task->type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); - slot.release(); - slot.i_batch = -1; - continue; // continue loop of slots - } - - GGML_ASSERT(slot.task->need_sampling()); - - // prompt evaluated for next-token prediction - slot.state = SLOT_STATE_GENERATING; - - if (slot.can_speculate()) { - common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); - } - } else if (slot.state != SLOT_STATE_GENERATING) { - continue; // continue loop of slots + if (slot.task->type == SERVER_TASK_TYPE_RERANK) { + send_rerank(slot, batch_view); + slot.release(); + slot.i_batch = -1; + return; } - if (slot.can_speculate() && !slot.spec_draft.empty()) { - continue; // sample using speculative decoding + GGML_ASSERT(slot.task->need_sampling()); + + // prompt evaluated for next-token prediction + slot.state = SLOT_STATE_GENERATING; + + if (slot.can_speculate()) { + common_speculative_begin(spec.get(), slot.id, slot.prompt.tokens.get_text_tokens()); + } + } else if (slot.state != SLOT_STATE_GENERATING) { + return; + } + + if (slot.can_speculate() && !slot.spec_draft.empty()) { + return; // sample using speculative decoding + } + + // shifted according to the current sub-batch + const int tok_idx = slot.i_batch - off; + + llama_token id; + { + scoped_timer timer(t_sampl, n_sampl); + id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); + } + + slot.i_batch = -1; + + common_sampler_accept(slot.smpl.get(), id, true); + + // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement + const int64_t t_now = ggml_time_us(); + + slot.n_decoded += 1; + + if (slot.n_decoded == 1) { + slot.t_start_generation = t_now; + slot.t_print_last = t_now; + slot.n_decoded_last = 0; + slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; + metrics.on_prompt_eval(slot); + } + + slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; + + completion_token_output result; + result.tok = id; + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs + + if (slot.task->params.sampling.n_probs > 0) { + populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); + } + + if (!process_token(result, slot)) { + // release slot because of stop condition + slot.print_timings(); + send_final_response(slot); + metrics.on_prediction(slot); + slot.release(); + + return; + } + + slot.print_timings_tg(); + }); + + // speculative decoding - main model sample and accept + iterate(slots, [&](server_slot & slot) { + if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { + return; + } + + // save the original draft size + const size_t n_draft = slot.spec_draft.size(); + + GGML_ASSERT(n_draft > 0); + + // verify and try to accept the draft + { + // save the sampler sampler state in case we need to restore it + common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); + + GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); + auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); + slot.spec_i_batch.clear(); + + GGML_ASSERT(accepted.size() >= 1); + + const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size(); + + const bool use_ckpt_tgt = + ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || + (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); + + // check for partial draft acceptance + if (n_rollback > 0) { + if (use_ckpt_tgt) { + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); + } + + // partial acceptance is not supported by the context -> truncate the draft and restore the state + slot.spec_draft = std::move(accepted); + + const auto & ckpt = slot.spec_ckpt; + + SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); + + { + ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); + } + + if (slot.ctx_dft) { + ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); + + common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); + } + + slot.prompt.tokens.keep_first(ckpt.n_tokens); + slot.smpl = std::move(smpl_save); + + return; + } } - const int tok_idx = slot.i_batch - i; + if (trace > 0) { + SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); + } - llama_token id = common_sampler_sample(slot.smpl.get(), slot.ctx_tgt, tok_idx); + common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); - slot.i_batch = -1; + slot.spec_draft = std::move(accepted); + } - common_sampler_accept(slot.smpl.get(), id, true); + const int64_t t_now = ggml_time_us(); - // here we have synchronized the llama_context (due to the sampling above), so we can do time measurement - const int64_t t_now = ggml_time_us(); + const auto ids = std::move(slot.spec_draft); + + slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; + + // update how many tokens out of those tested were accepted + slot.n_draft_accepted += ids.size() - 1; + slot.n_draft_verif_steps += 1; + + if (slot.n_accepted_per_pos.empty()) { + slot.n_accepted_per_pos.resize(common_speculative_n_max(¶ms_base.speculative), 0); + } + for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) { + slot.n_accepted_per_pos[i]++; + } + + // add accepted tokens to the prompt + slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); + slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); + + slot.sampled = ids.back(); // last accepted token + SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); + + common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); + if (slot.ctx_dft) { + common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); + } + + for (size_t i = 0; i < ids.size(); ++i) { + completion_token_output result; + + result.tok = ids[i]; + result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); + result.prob = 1.0f; // set later + + // TODO: set result.probs slot.n_decoded += 1; - if (slot.n_decoded == 1) { - slot.t_start_generation = t_now; - slot.t_print_last = t_now; - slot.n_decoded_last = 0; - slot.t_prompt_processing = (slot.t_start_generation - slot.t_start_process_prompt) / 1e3; - metrics.on_prompt_eval(slot); - } - - slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; - - completion_token_output result; - result.tok = id; - result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // TODO: set it here instead of doing inside populate_token_probs - - if (slot.task->params.sampling.n_probs > 0) { - populate_token_probs(slot, result, slot.task->params.post_sampling_probs, params_base.special, tok_idx); - } - if (!process_token(result, slot)) { - // release slot because of stop condition slot.print_timings(); send_final_response(slot); metrics.on_prediction(slot); slot.release(); - continue; + return; } - - slot.print_timings_tg(); } - // speculative decoding - main model sample and accept - for (auto & slot : slots) { - if (slot.state != SLOT_STATE_GENERATING || !slot.can_speculate() || slot.spec_draft.empty()) { - continue; - } + slot.print_timings_tg(); - // save the original draft size - const size_t n_draft = slot.spec_draft.size(); - - GGML_ASSERT(n_draft > 0); - - // verify and try to accept the draft - { - // save the sampler sampler state in case we need to restore it - common_sampler_ptr smpl_save(common_sampler_clone(slot.smpl.get())); - - GGML_ASSERT(slot.spec_i_batch.size() == n_draft + 1); - auto accepted = common_sampler_sample_and_accept_n(slot.smpl.get(), slot.ctx_tgt, slot.spec_i_batch, slot.spec_draft); - slot.spec_i_batch.clear(); - - GGML_ASSERT(accepted.size() >= 1); - - const uint32_t n_rollback = slot.spec_draft.size() + 1 - accepted.size(); - - const bool use_ckpt_tgt = - ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_FULL || - (ctx_tgt_seq_rm_type == COMMON_CONTEXT_SEQ_RM_TYPE_RS && n_rollback > llama_n_rs_seq(ctx_tgt)); - - // check for partial draft acceptance - if (n_rollback > 0) { - if (use_ckpt_tgt) { - if (trace > 0) { - SLT_INF(slot, "accepted %2zu/%2zu draft tokens (restore checkpoint)\n", accepted.size() - 1, slot.spec_draft.size()); - } - - // partial acceptance is not supported by the context -> truncate the draft and restore the state - slot.spec_draft = std::move(accepted); - - const auto & ckpt = slot.spec_ckpt; - - SLT_DBG(slot, "restoring speculative checkpoint (pos_min = %d, pos_max = %d, size = %zu)\n", ckpt.pos_min, ckpt.pos_max, ckpt.size()); - - { - ckpt.load_tgt(slot.ctx_tgt, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - common_context_seq_rm(slot.ctx_tgt, slot.id, ckpt.pos_max + 1, -1); - } - - if (slot.ctx_dft) { - ckpt.load_dft(slot.ctx_dft, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - - common_context_seq_rm(slot.ctx_dft, slot.id, ckpt.pos_max + 1, -1); - } - - slot.prompt.tokens.keep_first(ckpt.n_tokens); - slot.smpl = std::move(smpl_save); - - continue; - } - } - - if (trace > 0) { - SLT_INF(slot, "accepted %2zu/%2zu draft tokens\n", accepted.size() - 1, n_draft); - } - - common_speculative_accept(spec.get(), slot.id, accepted.size() - 1); - - slot.spec_draft = std::move(accepted); - } - - const int64_t t_now = ggml_time_us(); - - const auto ids = std::move(slot.spec_draft); - - slot.t_token_generation = std::max(1, t_now - slot.t_start_generation) / 1e3; - - // update how many tokens out of those tested were accepted - slot.n_draft_accepted += ids.size() - 1; - slot.n_draft_verif_steps += 1; - - if (slot.n_accepted_per_pos.empty()) { - slot.n_accepted_per_pos.resize(common_speculative_n_max(¶ms_base.speculative), 0); - } - for (size_t i = 0; i < ids.size() - 1 && i < slot.n_accepted_per_pos.size(); ++i) { - slot.n_accepted_per_pos[i]++; - } - - // add accepted tokens to the prompt - slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); - slot.prompt.tokens.insert({ids.begin(), ids.end() - 1}); - - slot.sampled = ids.back(); // last accepted token - SLT_DBG(slot, "add accepted tokens: sampled=%d, ids.size=%zu, n_draft=%zu\n", slot.sampled, ids.size(), n_draft); - - common_context_seq_rm(slot.ctx_tgt, slot.id, slot.prompt.tokens.pos_next(), -1); - if (slot.ctx_dft) { - common_context_seq_rm(slot.ctx_dft, slot.id, slot.prompt.tokens.pos_next(), -1); - } - - for (size_t i = 0; i < ids.size(); ++i) { - completion_token_output result; - - result.tok = ids[i]; - result.text_to_send = common_token_to_piece(slot.ctx_tgt, result.tok, accept_special_token(slot, result.tok)); - result.prob = 1.0f; // set later - - // TODO: set result.probs - - slot.n_decoded += 1; - - if (!process_token(result, slot)) { - slot.print_timings(); - send_final_response(slot); - metrics.on_prediction(slot); - slot.release(); - - break; - } - } - - slot.print_timings_tg(); - - SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens()); - } - } - - SRV_DBG("%s", "run slots completed\n"); + SLT_DBG(slot, "accepted %d/%d draft tokens, new n_tokens = %d\n", (int) ids.size() - 1, (int) n_draft, slot.prompt.n_tokens()); + }); } int get_slot_n_ctx() { @@ -3670,7 +3982,6 @@ server_context_meta server_context::get_meta() const { /* has_inp_audio */ impl->chat_params.allow_audio, /* has_inp_video */ impl->chat_params.allow_video, /* json_ui_settings */ impl->json_ui_settings, - /* json_webui_settings */ impl->json_webui_settings, // Deprecated /* slot_n_ctx */ impl->get_slot_n_ctx(), /* pooling_type */ llama_pooling_type(impl->ctx_tgt), @@ -3721,58 +4032,16 @@ struct server_res_generator : server_http_res { } }; -void server_context::on_sleeping_changed(std::function callback) { - impl->queue_tasks.on_sleeping_state(std::move(callback)); +void server_context::set_state_callback(server_state_callback_t callback) { + impl->callback_state = std::move(callback); + impl->queue_tasks.on_sleeping_state([this](bool sleeping) { + if (sleeping) { + impl->callback_state(SERVER_STATE_SLEEPING, {}); + } + // for sleeping == false, event is emitted by load_model() + }); } -// compute the number of tokens before the last user message in the prompt -static int32_t prompt_get_n_before_user( - const json & message_spans, - const std::string & prompt, - const std::vector & files, - const llama_vocab * vocab, - mtmd_context * mctx) { - int32_t result = -1; - int32_t byte_pos = -1; - - for (const auto & span : message_spans) { - const std::string role = json_value(span, "role", std::string()); - - if (role == "user") { - byte_pos = json_value(span, "pos", -1); - } - } - - if (byte_pos >= 0) { - GGML_ASSERT((size_t) byte_pos <= prompt.size()); - - const std::string prefix = prompt.substr(0, (size_t) byte_pos); - - const std::string marker = get_media_marker(); - size_t n_prefix_media = 0; - for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) { - n_prefix_media++; - } - - GGML_ASSERT(n_prefix_media <= files.size()); - - if (mctx != nullptr && n_prefix_media > 0) { - // TODO: this makes a copy - avoid it - std::vector prefix_files(files.begin(), files.begin() + n_prefix_media); - - result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size(); - } else { - result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size(); - } - - SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n", - byte_pos, n_prefix_media, result); - } - - return result; -} - - // // server_routes // @@ -3820,6 +4089,10 @@ std::unique_ptr server_routes::handle_completions_impl( // tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks + // message delimiters for checkpointing + auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array())); + delimiters.tokenize(ctx_server.vocab); + for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); @@ -3833,16 +4106,7 @@ std::unique_ptr server_routes::handle_completions_impl( meta->logit_bias_eog, data); - const auto message_spans = json_value(data, "message_spans", json::array()); - if (prompt.is_string() && message_spans.is_array()) { - task.params.n_before_user = - prompt_get_n_before_user( - message_spans, - prompt.get(), - files, - ctx_server.vocab, - ctx_server.mctx); - } + task.params.message_spans = task.tokens.find_message_spans(delimiters); task.id_slot = json_value(data, "id_slot", -1); @@ -4283,19 +4547,15 @@ void server_routes::init_routes() { { "endpoint_slots", params.endpoint_slots }, { "endpoint_props", params.endpoint_props }, { "endpoint_metrics", params.endpoint_metrics }, - // New keys - { "ui", params.ui }, - { "ui_settings", meta->json_ui_settings }, - // Deprecated: use ui/ui_settings instead (kept for backward compat) - { "webui", params.webui }, - { "webui_settings", meta->json_webui_settings }, + { "ui", params.ui }, + { "ui_settings", meta->json_ui_settings }, { "chat_template", tmpl_default }, { "chat_template_caps", meta->chat_template_caps }, { "bos_token", meta->bos_token_str }, { "eos_token", meta->eos_token_str }, { "build_info", meta->build_info }, { "is_sleeping", queue_tasks.is_sleeping() }, - { "cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy }, + { "cors_proxy_enabled", params.ui_mcp_proxy }, }; if (params.use_jinja) { if (!tmpl_tools.empty()) { diff --git a/tools/server/server-context.h b/tools/server/server-context.h index 0e84785af4..952f825f72 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -22,8 +22,7 @@ struct server_context_meta { bool has_inp_image; bool has_inp_audio; bool has_inp_video; - json json_ui_settings; // Primary: new name - json json_webui_settings; // Deprecated: use json_ui_settings instead (kept for backward compat) + json json_ui_settings; int slot_n_ctx; enum llama_pooling_type pooling_type; @@ -53,6 +52,33 @@ struct server_context_meta { uint64_t model_size; }; +enum server_state { + SERVER_STATE_DOWNLOADING, + SERVER_STATE_LOADING, + SERVER_STATE_READY, + SERVER_STATE_SLEEPING, +}; + +static std::string server_state_to_str(server_state state) { + switch (state) { + case SERVER_STATE_DOWNLOADING: return "downloading"; + case SERVER_STATE_LOADING: return "loading"; + case SERVER_STATE_READY: return "ready"; + case SERVER_STATE_SLEEPING: return "sleeping"; + default: GGML_ASSERT(false && "invalid server_state"); + } +} + +static server_state server_state_from_str(const std::string & str) { + if (str == "downloading") return SERVER_STATE_DOWNLOADING; + if (str == "loading") return SERVER_STATE_LOADING; + if (str == "ready") return SERVER_STATE_READY; + if (str == "sleeping") return SERVER_STATE_SLEEPING; + GGML_ASSERT(false && "invalid server_state string"); +} + +using server_state_callback_t = std::function; + struct server_context { std::unique_ptr impl; @@ -80,9 +106,8 @@ struct server_context { // not thread-safe, should only be used from the main thread server_context_meta get_meta() const; - // register a callback to be called when sleeping state changes - // must be set before load_model() is called - void on_sleeping_changed(std::function callback); + // note: must be set before load_model() is called + void set_state_callback(server_state_callback_t callback); }; diff --git a/tools/server/server-cors-proxy.h b/tools/server/server-cors-proxy.h index 2af0c7e1c2..53a6909ed2 100644 --- a/tools/server/server-cors-proxy.h +++ b/tools/server/server-cors-proxy.h @@ -7,9 +7,18 @@ #include #include #include +#include +#include #include "server-http.h" +static std::string proxy_header_to_lower(std::string header) { + std::transform(header.begin(), header.end(), header.begin(), [](unsigned char c) { + return std::tolower(c); + }); + return header; +} + static server_http_res_ptr proxy_request(const server_http_req & req, std::string method) { std::string target_url = req.get_param("url"); common_http_url parsed_url = common_http_parse_url(target_url); @@ -33,11 +42,18 @@ static server_http_res_ptr proxy_request(const server_http_req & req, std::strin SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str()); std::map headers; + const std::string proxy_header_prefix = "x-llama-server-proxy-header-"; for (auto [key, value] : req.headers) { - auto new_key = key; - if (string_starts_with(new_key, "x-proxy-header-")) { - string_replace_all(new_key, "x-proxy-header-", ""); + const std::string lowered_key = proxy_header_to_lower(key); + if (!string_starts_with(lowered_key, proxy_header_prefix)) { + continue; } + + auto new_key = key.substr(proxy_header_prefix.size()); + if (new_key.empty()) { + continue; + } + headers[new_key] = value; } diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 5defee1f5e..4f2abab00c 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -492,6 +492,8 @@ using server_http_req_ptr = std::unique_ptr; static void process_handler_response(server_http_req_ptr && request, server_http_res_ptr & response, httplib::Response & res) { if (response->is_stream()) { res.status = response->status; + // Tell Nginx to not buffer any streamed response + response->headers["X-Accel-Buffering"] = "no"; set_headers(res, response->headers); const std::string content_type = response->content_type; // convert to shared_ptr as both chunked_content_provider() and on_complete() need to use it diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 7aaad69261..bb2f43a10d 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -1,5 +1,6 @@ #include "server-common.h" #include "server-models.h" +#include "server-context.h" #include "build-info.h" #include "preset.h" @@ -44,9 +45,7 @@ extern char **environ; #define DEFAULT_STOP_TIMEOUT 10 // seconds #define CMD_ROUTER_TO_CHILD_EXIT "cmd_router_to_child:exit" -#define CMD_CHILD_TO_ROUTER_READY "cmd_child_to_router:ready" // also sent when waking up from sleep -#define CMD_CHILD_TO_ROUTER_SLEEP "cmd_child_to_router:sleep" -#define CMD_CHILD_TO_ROUTER_INFO "cmd_child_to_router:info:" // followed by json string +#define CMD_CHILD_TO_ROUTER_STATE "cmd_child_to_router:state:" // followed by json string // address for child process, this is needed because router may run on 0.0.0.0 // ref: https://github.com/ggml-org/llama.cpp/issues/17862 @@ -65,6 +64,17 @@ struct server_subproc { return sproc.has_value() && subprocess_alive(&sproc.value()); } + void request_exit() { + if (sproc.has_value()) { + FILE * stdin_file = subprocess_stdin(&sproc.value()); + if (stdin_file) { + fprintf(stdin_file, "%s\n", CMD_ROUTER_TO_CHILD_EXIT); + fflush(stdin_file); + } + } + stopped.store(true, std::memory_order_relaxed); + } + void terminate() { if (!sproc.has_value()) { return; @@ -213,8 +223,8 @@ void server_model_meta::update_caps() { "LLAMA_ARG_HF_REPO_FILE", }); params.offline = true; - // params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); + common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER); + common_models_handler_apply(handler, params); // note: this won't download the model because offline=true if (params.mmproj.path.empty()) { multimodal = { false, false }; } else { @@ -324,7 +334,7 @@ void server_models::notify_sse(const std::string & event, const std::string & mo } void server_models::load_models() { - // Phase 1: load presets from all sources โ€” pure I/O, no lock needed + // Phase 1: load presets from all sources - pure I/O, no lock needed // 1. cached models common_presets cached_models = ctx_preset.load_from_cache(); SRV_INF("Loaded %zu cached model presets\n", cached_models.size()); @@ -377,7 +387,7 @@ void server_models::load_models() { return source_map.count(name) ? source_map.at(name) : SERVER_MODEL_SOURCE_PRESET; }; - // Helpers that read `mapping` โ€” must be called while holding the lock. + // Helpers that read `mapping` - must be called while holding the lock. std::unordered_set custom_names; for (const auto & [name, preset] : custom_presets) custom_names.insert(name); auto join_set = [](const std::set & s) { @@ -443,6 +453,7 @@ void server_models::load_models() { /* last_used */ 0, /* args */ std::vector(), /* loaded_info */ {}, + /* progress */ {}, /* exit_code */ 0, /* stop_timeout */ DEFAULT_STOP_TIMEOUT, /* multimodal */ mtmd_caps{false, false}, @@ -523,7 +534,7 @@ void server_models::load_models() { } } - // join outside the lock โ€” monitoring thread calls update_status (needs lock) + // join outside the lock - monitoring thread calls update_status (needs lock) lk.unlock(); for (auto & th : threads_to_join) th.join(); lk.lock(); @@ -609,6 +620,7 @@ void server_models::load_models() { /* last_used */ 0, /* args */ std::vector(), /* loaded_info */ {}, + /* progress */ {}, /* exit_code */ 0, /* stop_timeout */ DEFAULT_STOP_TIMEOUT, /* multimodal */ mtmd_caps{false, false}, @@ -621,7 +633,7 @@ void server_models::load_models() { apply_stop_timeout(); - // clear reload flag before unlocking for autoload โ€” load() blocks on !is_reloading, + // clear reload flag before unlocking for autoload - load() blocks on !is_reloading, // so clearing it here (while still locked) prevents a deadlock in the autoload calls below is_reloading = false; cv.notify_all(); @@ -814,17 +826,23 @@ void server_models::unload_lru() { } void server_models::load(const std::string & name) { - if (!has_model(name)) { - throw std::runtime_error("model name=" + name + " is not found"); + load(name, load_options{}); +} + +void server_models::load(const std::string & name, const load_options & opts) { + if (!opts.custom_meta.has_value()) { + if (!has_model(name)) { + throw std::runtime_error("model name=" + name + " is not found"); + } + unload_lru(); } - unload_lru(); std::unique_lock lk(mutex); // edge case: block until any in-progress reload has finished so we always load // against the freshest preset and a consistent mapping state cv.wait(lk, [this]() { return !is_reloading; }); - auto meta = mapping[name].meta; + auto meta = opts.custom_meta.has_value() ? *opts.custom_meta : mapping[name].meta; if (meta.status != SERVER_MODEL_STATUS_UNLOADED) { SRV_INF("model %s is not ready\n", name.c_str()); return; @@ -868,6 +886,12 @@ void server_models::load(const std::string & name) { std::vector child_env = base_env; // copy child_env.push_back("LLAMA_SERVER_ROUTER_PORT=" + std::to_string(base_params.port)); + if (opts.mode == SERVER_CHILD_MODE_DOWNLOAD) { + inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING; + child_env.push_back("LLAMA_SERVER_CHILD_MODE=download"); + child_env.push_back("LLAMA_ARG_HF_REPO=" + name); + } + SRV_INF("%s", "spawning server instance with args:\n"); for (const auto & arg : child_args) { SRV_INF(" %s\n", arg.c_str()); @@ -885,13 +909,17 @@ void server_models::load(const std::string & name) { if (result != 0) { throw std::runtime_error("failed to spawn server instance"); } - - inst.stdin_file = subprocess_stdin(&inst.subproc->get()); } // start a thread to manage the child process // captured variables are guaranteed to be destroyed only after the thread is joined - inst.th = std::thread([this, name, child_proc = inst.subproc, port = inst.meta.port, stop_timeout = inst.meta.stop_timeout]() { + inst.th = std::thread([ + this, name, + child_proc = inst.subproc, + port = inst.meta.port, + stop_timeout = inst.meta.stop_timeout, + child_mode = opts.mode + ]() { FILE * stdin_file = subprocess_stdin(&child_proc->get()); FILE * stdout_file = subprocess_stdout(&child_proc->get()); // combined stdout/stderr @@ -904,12 +932,8 @@ void server_models::load(const std::string & name) { while (fgets(buffer, vec_buf.size(), stdout_file) != nullptr) { LOG("[%5d] %s", port, buffer); std::string str(buffer); - if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_READY)) { - this->update_status(name, SERVER_MODEL_STATUS_LOADED, 0); - } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_INFO)) { - this->update_loaded_info(name, str); - } else if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_SLEEP)) { - this->update_status(name, SERVER_MODEL_STATUS_SLEEPING, 0); + if (string_starts_with(buffer, CMD_CHILD_TO_ROUTER_STATE)) { + this->handle_child_state(name, str); } } } else { @@ -928,7 +952,7 @@ void server_models::load(const std::string & name) { return is_stopping() || child_proc->stopped.load(std::memory_order_acquire); }); } - // child crashed or finished on its own โ€” skip graceful shutdown sequence + // child crashed or finished on its own, skip graceful shutdown sequence if (child_proc->stopped.load(std::memory_order_acquire)) { return; } @@ -976,7 +1000,14 @@ void server_models::load(const std::string & name) { subprocess_destroy(&child_proc->get()); // update status and exit code - this->update_status(name, SERVER_MODEL_STATUS_UNLOADED, exit_code); + if (child_mode == SERVER_CHILD_MODE_DOWNLOAD) { + // instance will be cleaned up on next load_models() call + } else { + this->update_status(name, { + SERVER_MODEL_STATUS_UNLOADED, + exit_code + }); + } SRV_INF("instance name=%s exited with status %d\n", name.c_str(), exit_code); }); @@ -984,7 +1015,7 @@ void server_models::load(const std::string & name) { { auto & old_instance = mapping[name]; // old process should have exited already, but just in case, we clean it up here - if (old_instance.subproc->is_alive()) { + if (old_instance.subproc && old_instance.subproc->is_alive()) { SRV_WRN("old process for model name=%s is still alive, this is unexpected\n", name.c_str()); old_instance.subproc->terminate(); // force kill } @@ -1001,90 +1032,13 @@ void server_models::load(const std::string & name) { cv.notify_all(); } -// callback for model downloading functionality -struct server_models_download_res : public common_download_callback { - common_params_model model; - common_download_opts opts; - - std::function should_stop; - std::function on_progress; - - bool is_ok = false; - - bool run() { - try { - common_download_model(model, opts); - is_ok = true; - } catch (const std::exception & e) { - SRV_ERR("download failed for model name=%s: %s\n", model.name.c_str(), e.what()); - is_ok = false; - } - return is_ok; - } - void on_start(const common_download_progress & p) override { - on_progress(p); - } - void on_update(const common_download_progress & p) override { - on_progress(p); - } - void on_done(const common_download_progress &, bool ok) override { - is_ok = ok; - } - bool is_cancelled() const override { - return should_stop(); - } -}; - -void server_models::download(common_params_model && model, common_download_opts && opts) { - std::string name = model.name; - GGML_ASSERT(name == model.hf_repo); - - std::unique_lock lk(mutex); - if (mapping.find(name) != mapping.end()) { - throw std::runtime_error("model name=" + name + " already exists"); - } - - instance_t inst; - inst.meta.name = name; - inst.meta.status = SERVER_MODEL_STATUS_DOWNLOADING; - inst.subproc = std::make_shared(); - - auto dl = std::make_unique(); - dl->model = model; // copy - dl->opts = opts; // copy - - dl->should_stop = [sp = inst.subproc]() { - return sp->stopped.load(std::memory_order_relaxed); - }; - - dl->on_progress = [this, name](const common_download_progress & p) { - update_download_progress(name, p, false); - }; - - inst.th = std::thread([this, dl = std::move(dl)]() { - dl->opts.callback = dl.get(); - bool ok = dl->run(); - SRV_INF("download finished for model name=%s with status=%s\n", - dl->model.name.c_str(), ok ? "success" : "failure"); - update_download_progress(dl->model.name, {}, true, ok); - // need_reload is set inside update_download_progress under the mutex; - // the next load_models() call will clean up this instance - }); - - mapping[name] = std::move(inst); - notify_sse("status_update", name, { - {"status", server_model_status_to_string(SERVER_MODEL_STATUS_DOWNLOADING)}, - }); - cv.notify_all(); -} - void server_models::unload(const std::string & name) { std::unique_lock lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { SRV_INF("cancelling download for model name=%s\n", name.c_str()); - it->second.subproc->stopped.store(true, std::memory_order_relaxed); + it->second.subproc->request_exit(); // for convenience, we wait the status change here wait(lk, name, [](const server_model_meta & new_meta) { return new_meta.status != SERVER_MODEL_STATUS_DOWNLOADING; @@ -1130,21 +1084,33 @@ void server_models::unload_all() { } } -void server_models::update_status(const std::string & name, server_model_status status, int exit_code) { +void server_models::update_status(const std::string & name, const update_status_args & args) { std::unique_lock lk(mutex); auto it = mapping.find(name); if (it != mapping.end()) { auto & meta = it->second.meta; - meta.status = status; - meta.exit_code = exit_code; + meta.status = args.status; + meta.exit_code = args.exit_code; + if (!args.loaded_info.is_null()) { + meta.loaded_info = args.loaded_info; + } + if (!args.progress.is_null()) { + meta.progress = args.progress; + } } // broadcast status change to SSE { json data = { - {"status", server_model_status_to_string(status)}, + {"status", server_model_status_to_string(args.status)}, }; - if (status == SERVER_MODEL_STATUS_UNLOADED) { - data["exit_code"] = exit_code; + if (args.status == SERVER_MODEL_STATUS_UNLOADED) { + data["exit_code"] = args.exit_code; + } + if (!args.loaded_info.is_null()) { + data["info"] = args.loaded_info; + } + if (!args.progress.is_null()) { + data["progress"] = args.progress; } // note: notify_sse doesn't acquire the lock, so no deadlock here notify_sse("status_change", name, data); @@ -1152,29 +1118,6 @@ void server_models::update_status(const std::string & name, server_model_status cv.notify_all(); } -void server_models::update_loaded_info(const std::string & name, std::string & raw_info) { - if (!string_starts_with(raw_info, CMD_CHILD_TO_ROUTER_INFO)) { - SRV_WRN("invalid loaded info format from child for model name=%s: %s\n", name.c_str(), raw_info.c_str()); - return; - } - - json info; - try { - info = json::parse(raw_info.substr(strlen(CMD_CHILD_TO_ROUTER_INFO))); - } catch (const std::exception & e) { - SRV_WRN("failed to parse loaded info from child for model name=%s: %s\n", name.c_str(), e.what()); - return; - } - - std::unique_lock lk(mutex); - auto it = mapping.find(name); - if (it != mapping.end()) { - auto & meta = it->second.meta; - meta.loaded_info = info; - } - cv.notify_all(); -} - void server_models::update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok) { json curr; { @@ -1207,37 +1150,65 @@ void server_models::update_download_progress(const std::string & name, const com } bool server_models::remove(const std::string & name) { - auto meta = get_meta(name); + // do everything under one lock acquisition; avoid get_meta() / + // unload() because they can trigger load_models() which erases + // transient DOWNLOADING / DOWNLOADED entries as a side-effect + std::unique_lock lk(mutex); - if (!meta.has_value()) { + auto it = mapping.find(name); + if (it == mapping.end()) { throw std::runtime_error("model name=" + name + " is not found"); } - if (meta->source != SERVER_MODEL_SOURCE_CACHE) { + if (it->second.meta.source != SERVER_MODEL_SOURCE_CACHE) { throw std::runtime_error("model name=" + name + " is not removable (not from cache)"); } - unload(name); // cancel download or stop running instance - { - std::unique_lock lk(mutex); - // a cancelled download lands on DOWNLOADED; a stopped instance lands on UNLOADED - wait(lk, name, [](const server_model_meta & new_meta) { - return new_meta.status == SERVER_MODEL_STATUS_UNLOADED - || new_meta.status == SERVER_MODEL_STATUS_DOWNLOADED; - }); - // join before erasing - after status reaches UNLOADED/DOWNLOADED the thread no - // longer acquires this mutex, so joining while holding it is safe - if (mapping[name].th.joinable()) { - mapping[name].th.join(); + if (it->second.meta.status == SERVER_MODEL_STATUS_DOWNLOADING) { + // cancel in-flight download + SRV_INF("cancelling download for model name=%s\n", name.c_str()); + it->second.subproc->request_exit(); + } else if (it->second.meta.is_running()) { + // stop running instance + SRV_INF("stopping model instance name=%s\n", name.c_str()); + stopping_models.insert(name); + if (it->second.meta.status == SERVER_MODEL_STATUS_LOADING) { + it->second.subproc->terminate(); } - // remove the model from disk (hold lock to prevent concurrent load) - bool ok = common_download_remove(name); - if (ok) { - mapping.erase(name); - } - SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "failed"); - notify_sse("model_remove", name, {}); - return ok; + cv_stop.notify_all(); } + + // wait until the monitoring thread finishes + wait(lk, name, [](const server_model_meta & meta) { + return meta.status == SERVER_MODEL_STATUS_UNLOADED + || meta.status == SERVER_MODEL_STATUS_DOWNLOADED; + }); + + // re-find after wait - load_models() may have erased the entry during the wait + it = mapping.find(name); + if (it == mapping.end()) { + // load_models() already joined the thread and erased the entry; + // we just need to clean up the cached files on disk + lk.unlock(); + bool ok = common_download_remove(name); + SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial"); + notify_sse("model_remove", name, {}); + return true; + } + + // join before erasing - thread no longer acquires this mutex + if (it->second.th.joinable()) { + it->second.th.join(); + } + + // remove from disk (best-effort: cancelled downloads may have no cached files) + bool ok = common_download_remove(name); + mapping.erase(name); + if (!ok) { + SRV_WRN("removing model name=%s from disk returned false (no cached files?)\n", name.c_str()); + } + SRV_INF("removing model name=%s from cache (%s)\n", name.c_str(), ok ? "succeeded" : "partial"); + notify_sse("model_remove", name, {}); + return true; } void server_models::wait(const std::string & name, std::function predicate) { @@ -1252,7 +1223,9 @@ void server_models::wait(std::unique_lock & lk, const std::string & return predicate(it->second.meta); } - return false; + // model was removed from mapping by another code path (e.g. load_models()). + // nothing left to wait for - tell the caller to proceed. + return true; }); } @@ -1323,21 +1296,169 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co return proxy; } -bool server_models::is_child_server() { +void server_models::handle_child_state(const std::string & name, const std::string & raw_input) { + server_state state; + json payload; + + try { + json data = json::parse(raw_input.substr(strlen(CMD_CHILD_TO_ROUTER_STATE))); + state = server_state_from_str(json_value(data, "state", std::string())); + payload = json_value(data, "payload", json{}); + } catch (const std::exception & e) { + SRV_ERR("failed to parse child state update for name=%s: %s\n", name.c_str(), e.what()); + return; + } + + switch (state) { + case SERVER_STATE_DOWNLOADING: + { + std::string result = json_value(payload, "result", std::string()); + std::string url = json_value(payload, "url", std::string()); + auto request_exit = [&]() { + std::lock_guard lk(mutex); + auto it = mapping.find(name); + if (it != mapping.end()) { + return it->second.subproc->request_exit(); + } + }; + if (result == "download_finished") { + update_download_progress(name, {}, true, true); + request_exit(); + } else if (result == "download_failed") { + update_download_progress(name, {}, true, false); + request_exit(); + } else if (!url.empty()) { + common_download_progress p; + p.url = url; + p.downloaded = json_value(payload, "downloaded", (size_t)0); + p.total = json_value(payload, "total", (size_t)0); + update_download_progress(name, p, false); + } + } break; + case SERVER_STATE_LOADING: + { + update_status(name, { + SERVER_MODEL_STATUS_LOADING, + 0, + nullptr, // no loaded_info yet + payload, + }); + } break; + case SERVER_STATE_READY: + { + update_status(name, { + SERVER_MODEL_STATUS_LOADED, + 0, + // note: payload can be empty if this is a wakeup from sleep + payload.size() > 0 ? payload : nullptr, + {}, // reset progress info + }); + } break; + case SERVER_STATE_SLEEPING: + { + update_status(name, { SERVER_MODEL_STATUS_SLEEPING }); + } break; + default: + // should never happen, but just in case + GGML_ASSERT(false && "unexpected state from child server"); + } +} + +// +// server_child +// + +bool server_child::is_child() { const char * router_port = std::getenv("LLAMA_SERVER_ROUTER_PORT"); return router_port != nullptr; } -std::thread server_models::setup_child_server(const std::function & shutdown_handler, const json & model_info) { - // send a notification to the router server that a model instance is ready - common_log_pause(common_log_main()); - fflush(stdout); - fprintf(stdout, "%s\n", CMD_CHILD_TO_ROUTER_READY); - fflush(stdout); - fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_INFO, safe_json_to_str(model_info).c_str()); - fflush(stdout); - common_log_resume(common_log_main()); +server_child_mode server_child::get_mode() { + const char * mode = std::getenv("LLAMA_SERVER_CHILD_MODE"); + std::string mode_str(mode ? mode : ""); + if (mode_str == "download") { + return SERVER_CHILD_MODE_DOWNLOAD; + } else { + return SERVER_CHILD_MODE_NORMAL; + } +} +struct server_download_state : public common_download_callback { + server_child * self; + std::function should_stop; + std::atomic last_progress_time{0}; // multiple files downloading in different threads + bool is_ok = false; + + server_download_state(server_child * s) : self(s) {} + + bool run(common_params & params) { + try { + common_models_handler handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER); + common_models_handler_apply(handler, params, this); + is_ok = true; + } catch (const std::exception & e) { + auto model_name = params.model.get_name(); + SRV_ERR("download failed for model name=%s: %s\n", model_name.c_str(), e.what()); + is_ok = false; + } + return is_ok; + } + void on_progress(const common_download_progress & p) { + json data = { + {"url", p.url}, + {"downloaded", p.downloaded}, + {"total", p.total}, + }; + self->notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), data); + } + void on_start(const common_download_progress & p) override { + on_progress(p); + } + void on_update(const common_download_progress & p) override { + int64_t now = ggml_time_ms(); + // throttle progress updates to avoid flooding logs + if (now - last_progress_time.load(std::memory_order_relaxed) >= 100) { + on_progress(p); + last_progress_time.store(now, std::memory_order_relaxed); + } + } + void on_done(const common_download_progress & p, bool) override { + on_progress(p); + } + bool is_cancelled() const override { + return should_stop ? should_stop() : false; + } +}; + +int server_child::run_download(common_params & params) { + auto cancelled = std::make_shared>(false); + + // monitor stdin for cancellation command from the router + std::thread signal_thread = setup([cancelled](int) { + cancelled->store(true, std::memory_order_relaxed); + }); + + server_download_state dl(this); + dl.should_stop = [cancelled]() { + return cancelled->load(std::memory_order_relaxed); + }; + + bool ok = dl.run(params); + + notify_to_router(server_state_to_str(SERVER_STATE_DOWNLOADING), { + {"result", ok ? "download_finished" : "download_failed"}, + }); + + // router should send CMD_ROUTER_TO_CHILD_EXIT after receiving the result + if (signal_thread.joinable()) { + signal_thread.join(); + } + + SRV_INF("download completed %s\n", ok ? "successfully" : "with errors"); + return 0; +} + +std::thread server_child::setup(const std::function & shutdown_handler) { // setup thread for monitoring stdin return std::thread([shutdown_handler]() { // wait for EOF on stdin @@ -1363,10 +1484,15 @@ std::thread server_models::setup_child_server(const std::function & s }); } -void server_models::notify_router_sleeping_state(bool is_sleeping) { +void server_child::notify_to_router(const std::string & state, const json & payload) { + json data = { + {"state", state}, + {"payload", payload}, + }; + std::lock_guard lk(mtx_stdout); common_log_pause(common_log_main()); fflush(stdout); - fprintf(stdout, "%s\n", is_sleeping ? CMD_CHILD_TO_ROUTER_SLEEP : CMD_CHILD_TO_ROUTER_READY); + fprintf(stdout, "%s%s\n", CMD_CHILD_TO_ROUTER_STATE, safe_json_to_str(data).c_str()); fflush(stdout); common_log_resume(common_log_main()); } @@ -1462,9 +1588,9 @@ void server_models_routes::init_routes() { auto res = std::make_unique(); res_ok(res, { // TODO: add support for this on web UI - {"role", "router"}, - {"max_instances", params.models_max}, - {"models_autoload", params.models_autoload}, + {"role", "router"}, + {"max_instances", params.models_max}, + {"models_autoload", params.models_autoload}, // this is a dummy response to make sure the UI doesn't break {"model_alias", "llama-server"}, {"model_path", "none"}, @@ -1473,11 +1599,9 @@ void server_models_routes::init_routes() { {"n_ctx", 0}, }}, // New key - {"ui_settings", ui_settings}, - // Deprecated: use ui_settings instead (kept for backward compat) - {"webui_settings", webui_settings}, - {"build_info", std::string(llama_build_info())}, - {"cors_proxy_enabled", params.ui_mcp_proxy || params.webui_mcp_proxy}, + {"ui_settings", ui_settings}, + {"build_info", std::string(llama_build_info())}, + {"cors_proxy_enabled", params.ui_mcp_proxy}, }); return res; } @@ -1607,7 +1731,7 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is not found", ERROR_TYPE_INVALID_REQUEST)); return res; } - if (!model->is_running()) { + if (!model->is_running() && model->status != SERVER_MODEL_STATUS_DOWNLOADING) { res_err(res, format_error_response("model is not running", ERROR_TYPE_INVALID_REQUEST)); return res; } @@ -1643,23 +1767,14 @@ void server_models_routes::init_routes() { throw std::invalid_argument("model must be a non-empty string"); } - common_params_model model; - common_download_opts opts; + common_params p; + p.model.hf_repo = name; + p.hf_token = params.hf_token; - model.name = name; - model.hf_repo = name; - opts.bearer_token = params.hf_token; - opts.download_mmproj = true; - opts.download_mtp = true; - - // first, only check if the model is valid and can be downloaded - opts.skip_download = true; + // validate by fetching metadata bool ok = false; try { - auto validation = common_download_model(model, opts); - ok = !validation.model_path.empty(); - } catch (const common_skip_download_exception &) { - // model is valid and will be downloaded + common_models_handler_init(p, LLAMA_EXAMPLE_SERVER); ok = true; } catch (...) { SRV_ERR("unknown error while validating model '%s'\n", name.c_str()); @@ -1671,10 +1786,21 @@ void server_models_routes::init_routes() { throw std::invalid_argument("model validation failed, unable to download"); } + // reject if model already exists + if (models.has_model(name)) { + throw std::invalid_argument("model '" + name + "' already exists"); + } + // then, proceed with the actual download - opts.skip_download = false; SRV_INF("starting download for model '%s'\n", name.c_str()); - models.download(std::move(model), std::move(opts)); + { + server_models::load_options load_opts; + load_opts.mode = SERVER_CHILD_MODE_DOWNLOAD; + load_opts.custom_meta = server_model_meta{}; + load_opts.custom_meta->source = SERVER_MODEL_SOURCE_CACHE; + load_opts.custom_meta->name = name; + models.load(name, load_opts); + } res_ok(res, {{"success", true}}); return res; @@ -1688,10 +1814,7 @@ void server_models_routes::init_routes() { throw std::invalid_argument("model must be a non-empty string"); } - bool ok = models.remove(name); - if (!ok) { - throw std::runtime_error("failed to remove model '" + name + "'"); - } + models.remove(name); // throws on error res_ok(res, {{"success", true}}); return res; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 319c4352e2..9ed4aeead0 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -40,6 +40,11 @@ enum server_model_source { SERVER_MODEL_SOURCE_CACHE, }; +enum server_child_mode { + SERVER_CHILD_MODE_NORMAL, // load the model and run normally + SERVER_CHILD_MODE_DOWNLOAD, // download the model and exit +}; + static std::string server_model_status_to_string(server_model_status status) { switch (status) { case SERVER_MODEL_STATUS_DOWNLOADING: return "downloading"; @@ -72,6 +77,7 @@ struct server_model_meta { int64_t last_used = 0; // for LRU unloading std::vector args; // args passed to the model instance, will be populated by render_args() json loaded_info; // info to be reflected via /v1/models endpoint ; if in DOWNLOADING state, it should contain download progress info + json progress; // reflect load or download progress info, if any int exit_code = 0; // exit code of the model instance process (only valid if status == FAILED) int stop_timeout = 0; // seconds to wait before force-killing the model instance during shutdown mtmd_caps multimodal; // multimodal capabilities @@ -104,7 +110,6 @@ private: std::shared_ptr subproc; // shared between main thread and monitoring thread std::thread th; server_model_meta meta; - FILE * stdin_file = nullptr; }; std::mutex mutex; @@ -160,19 +165,28 @@ public: // return a copy of all model metadata (thread-safe) std::vector get_all_meta(); + struct load_options { + server_child_mode mode = SERVER_CHILD_MODE_NORMAL; + // used for spawning a downloading child process + std::optional custom_meta = std::nullopt; + }; + // load and unload model instances // these functions are thread-safe void load(const std::string & name); + void load(const std::string & name, const load_options & opts); void unload(const std::string & name); void unload_all(); - // download a new model, progress is reported via SSE - // to stop the download, call unload() - void download(common_params_model && model, common_download_opts && opts); - + struct update_status_args { + server_model_status status; + int exit_code = 0; // only valid if status == UNLOADED + json loaded_info = nullptr; + json progress = nullptr; + }; // update the status of a model instance (thread-safe) - void update_status(const std::string & name, server_model_status status, int exit_code); - void update_loaded_info(const std::string & name, std::string & raw_info); + // also send SSE notification to /models/sse endpoint + void update_status(const std::string & name, const update_status_args & args); void update_download_progress(const std::string & name, const common_download_progress & progress, bool done, bool ok = true); // remove a cache model from disk and update the list (thread-safe) @@ -193,34 +207,47 @@ public: // proxy an HTTP request to the model instance server_http_res_ptr proxy_request(const server_http_req & req, const std::string & method, const std::string & name, bool update_last_used); + // handle message sent from server_child::notify_to_router() + // raw input must starts with CMD_CHILD_TO_ROUTER_STATE, followed by a JSON string + // this function is not thread-safe, must be called from instance's monitoring thread + // payload per state: + // state = loading -> payload = {} (TODO: add progress info) + // state = ready -> payload = model_info (json), or {} if wakeup from sleeping + // state = sleeping -> payload = {} + void handle_child_state(const std::string & name, const std::string & raw_input); +}; + +struct server_child { + // serializes the notify_to_router writes + std::mutex mtx_stdout; + std::atomic is_finished_downloading = false; // set by run_download + // return true if the current process is a child server instance - static bool is_child_server(); + bool is_child(); + server_child_mode get_mode(); + int run_download(common_params & params); - // notify the router server that a model instance is ready + // register the shutdown_handler to be called by the router // return the monitoring thread (to be joined by the caller) - static std::thread setup_child_server(const std::function & shutdown_handler, const json & model_info); + std::thread setup(const std::function & shutdown_handler); - // notify the router server that the sleeping state has changed - static void notify_router_sleeping_state(bool sleeping); + // notify router server for status changes (e.g. loading, downloading, sleeping, etc.) + // message will be handled by server_models::handle_child_state() on the router side + void notify_to_router(const std::string & state_name, const json & payload); }; struct server_models_routes { common_params params; json ui_settings = json::object(); // Primary: new name - json webui_settings = json::object(); // Deprecated: use ui_settings (kept for compat) std::atomic stopping = false; // for graceful disconnecting SSE clients during shutdown server_models models; server_models_routes(const common_params & params, int argc, char ** argv) : params(params), models(params, argc, argv) { - // Support both new ui_config_json and deprecated webui_config_json - const std::string & cfg = !this->params.ui_config_json.empty() - ? this->params.ui_config_json - : this->params.webui_config_json; + const std::string & cfg = this->params.ui_config_json; if (!cfg.empty()) { try { json json_settings = json::parse(cfg); ui_settings = json_settings; - webui_settings = json_settings; // Deprecated: keep in sync } catch (const std::exception & e) { LOG_ERR("%s: failed to parse UI config: %s\n", __func__, e.what()); throw; diff --git a/tools/server/server-schema.cpp b/tools/server/server-schema.cpp index d5d747a654..ed4bda2412 100644 --- a/tools/server/server-schema.cpp +++ b/tools/server/server-schema.cpp @@ -14,6 +14,9 @@ std::vector> make_llama_cmpl_schema(const common_params & fields.emplace_back(f); }; + add((new field_bool("verbose", params.verbose)) + ->set_desc("Include __verbose field in the response with additional debug information")); + add((new field_bool("timings_per_token", params.timings_per_token)) ->set_desc("Include prompt processing and text generation speed information in each response")); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 9ba039c8b8..a9ebac013f 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -591,10 +591,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp() { for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) { output.push_back(json { + {"id", "fc_" + tool_call.id}, {"type", "function_call"}, {"status", "completed"}, {"arguments", tool_call.arguments}, - {"call_id", "fc_" + tool_call.id}, + {"call_id", "call_" + tool_call.id}, {"name", tool_call.name}, }); } @@ -690,10 +691,11 @@ json server_task_result_cmpl_final::to_json_oaicompat_resp_stream() { for (const common_chat_tool_call & tool_call : oaicompat_msg.tool_calls) { const json output_item = { + {"id", "fc_" + tool_call.id}, {"type", "function_call"}, {"status", "completed"}, {"arguments", tool_call.arguments}, - {"call_id", "fc_" + tool_call.id}, + {"call_id", "call_" + tool_call.id}, {"name", tool_call.name} }; server_sent_events.push_back(json { @@ -1277,8 +1279,9 @@ json server_task_result_cmpl_partial::to_json_oaicompat_resp() { {"data", json { {"type", "response.output_item.added"}, {"item", json { + {"id", "fc_" + diff.tool_call_delta.id}, {"arguments", ""}, - {"call_id", "fc_" + diff.tool_call_delta.id}, + {"call_id", "call_" + diff.tool_call_delta.id}, {"name", diff.tool_call_delta.name}, {"type", "function_call"}, {"status", "in_progress"}, diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 299c279d7d..293bdf053a 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -62,9 +62,6 @@ struct task_params { int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) - // number of prompt tokens before the latest user message - int32_t n_before_user = -1; - int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit @@ -92,6 +89,9 @@ struct task_params { // per-request parameters for chat parsing common_chat_parser_params chat_parser_params; + // message spans for checkpointing + common_chat_msg_spans message_spans; + // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) diff --git a/tools/server/server-tools.cpp b/tools/server/server-tools.cpp index 97433fe4b5..790ed85a06 100644 --- a/tools/server/server-tools.cpp +++ b/tools/server/server-tools.cpp @@ -11,6 +11,7 @@ #include #include #include +#include namespace fs = std::filesystem; @@ -568,9 +569,13 @@ struct server_tool_edit_file : server_tool { } int n = (int) lines.size(); if (e.line_start == -1) { - // -1 means end of file; line_end is ignored โ€” normalize to point past last line - e.line_start = n + 1; - e.line_end = n + 1; + // -1 targets end of file -> valid for append only; line_end is ignored + if (e.mode != "append") { + return {{"error", "line_start -1 (end of file) is only valid for append mode"}}; + } + // append at end of file: insert position is the current line count + e.line_start = n; + e.line_end = n; } else { if (e.line_start < 1 || e.line_end < e.line_start) { return {{"error", string_format("invalid line range [%d, %d]", e.line_start, e.line_end)}}; @@ -611,8 +616,8 @@ struct server_tool_edit_file : server_tool { } else if (e.mode == "delete") { lines.erase(lines.begin() + idx_start, lines.begin() + idx_end + 1); } else { // append - // idx_end + 1 may equal lines.size() when line_start == -1 (end of file) - lines.insert(lines.begin() + idx_end + 1, new_lines.begin(), new_lines.end()); + // insert after idx_end; idx_end + 1 == lines.size() for end-of-file append + lines.insert(lines.begin() + (idx_end + 1), new_lines.begin(), new_lines.end()); } } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 78ab0318cf..680590871f 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -89,9 +89,23 @@ int llama_server(int argc, char ** argv) { llama_backend_init(); llama_numa_init(params.numa); + common_models_handler models_handler; + try { + models_handler = common_models_handler_init(params, LLAMA_EXAMPLE_SERVER); + if (common_models_handler_is_preset_repo(models_handler)) { + // apply the preset and start the server in router mode + common_models_handler_apply(models_handler, params); + } + } catch (const std::exception & e) { + SRV_ERR("failed to fetch model metadata: %s\n", e.what()); + return 1; + } + // router server never loads a model and must not touch the GPU + const bool is_router_server = params.model.path.empty() + && params.model.hf_repo.empty(); + // skip device enumeration so the CUDA primary context stays uncreated - const bool is_router_server = params.model.path.empty(); common_params_print_info(params, !is_router_server); if (!is_router_server) { @@ -113,8 +127,9 @@ int llama_server(int argc, char ** argv) { } // for consistency between server router mode and single-model mode, we set the same model name as alias - if (params.model_alias.empty() && !params.model.name.empty()) { - params.model_alias.insert(params.model.name); + auto model_name = params.model.get_name(); + if (params.model_alias.empty() && !model_name.empty()) { + params.model_alias.insert(model_name); } // struct that contains llama context and inference @@ -131,6 +146,7 @@ int llama_server(int argc, char ** argv) { // // register API routes + server_child child; // only used in non-router mode server_routes routes(params, ctx_server); server_tools tools; @@ -227,8 +243,7 @@ int llama_server(int argc, char ** argv) { ctx_http.register_gcp_compat(); // CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP) - // Supports both new ui_mcp_proxy and deprecated webui_mcp_proxy fields - if (params.ui_mcp_proxy || params.webui_mcp_proxy) { + if (params.ui_mcp_proxy) { SRV_WRN("%s", "-----------------\n"); SRV_WRN("%s", "CORS proxy is enabled, do not expose server to untrusted environments\n"); SRV_WRN("%s", "This feature is EXPERIMENTAL and may be removed or changed in future versions\n"); @@ -252,6 +267,22 @@ int llama_server(int argc, char ** argv) { ctx_http.post("/tools", ex_wrapper(tools.handle_post)); } + // + // Handle downloading model + // + + if (child.is_child() && child.get_mode() == SERVER_CHILD_MODE_DOWNLOAD) { + return child.run_download(params); + } else if (!is_router_server) { + // single-model mode (NOT spawned by router) + try { + common_models_handler_apply(models_handler, params); + } catch (const std::exception & e) { + SRV_ERR("failed to download model: %s\n", e.what()); + return 1; + } + } + // // Start the server // @@ -301,15 +332,16 @@ int llama_server(int argc, char ** argv) { return 1; } - // load the model - SRV_INF("%s", "loading model\n"); - - if (server_models::is_child_server()) { - ctx_server.on_sleeping_changed([&](bool sleeping) { - server_models::notify_router_sleeping_state(sleeping); + // setup communication child --> router if necessary + if (child.is_child()) { + ctx_server.set_state_callback([&](server_state state, json payload) { + child.notify_to_router(server_state_to_str(state), payload); }); } + // load the model + SRV_INF("%s", "loading model\n"); + if (!ctx_server.load_model(params)) { clean_up(); if (ctx_http.thread.joinable()) { @@ -366,9 +398,9 @@ int llama_server(int argc, char ** argv) { // optionally, notify router server that this instance is ready std::thread monitor_thread; - if (server_models::is_child_server()) { - json model_info = routes.get_model_info(); - monitor_thread = server_models::setup_child_server(shutdown_handler, model_info); + if (child.is_child()) { + monitor_thread = child.setup(shutdown_handler); + child.notify_to_router(server_state_to_str(SERVER_STATE_READY), routes.get_model_info()); } // this call blocks the main thread until queue_tasks.terminate() is called diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index d1b89cf1a9..285726abf4 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -79,9 +79,9 @@ def test_load_split_model(): assert match_regex("(little|girl)+", res.body["content"]) -def test_no_webui(): +def test_no_ui(): global server - # default: webui enabled + # default: UI enabled server.start() url = f"http://{server.server_host}:{server.server_port}" res = requests.get(url) @@ -89,8 +89,8 @@ def test_no_webui(): assert "" in res.text server.stop() - # with --no-webui - server.no_webui = True + # with --no-ui, the UI should be disabled + server.no_ui = True server.start() res = requests.get(url) assert res.status_code == 404 diff --git a/tools/server/tests/unit/test_chat_completion.py b/tools/server/tests/unit/test_chat_completion.py index b00aac649d..0258b539ed 100644 --- a/tools/server/tests/unit/test_chat_completion.py +++ b/tools/server/tests/unit/test_chat_completion.py @@ -603,3 +603,23 @@ def test_chat_completions_token_count(): }) assert res.status_code == 200 assert res.body["input_tokens"] > 5 + + +def test_verbose_debug(): + global server + server.start() + for verbose in [True, False]: + res = server.make_request("POST", "/chat/completions", data={ + "max_tokens": 2, + "messages": [ + {"role": "system", "content": "Book"}, + {"role": "user", "content": "What is the best book"}, + ], + "verbose": verbose, + }) + assert res.status_code == 200 + if verbose: + assert "__verbose" in res.body + assert "Book" in res.body["__verbose"]["prompt"] + else: + assert "__verbose" not in res.body diff --git a/tools/server/tests/unit/test_proxy.py b/tools/server/tests/unit/test_proxy.py index b7c3326187..3b86d80473 100644 --- a/tools/server/tests/unit/test_proxy.py +++ b/tools/server/tests/unit/test_proxy.py @@ -12,7 +12,7 @@ def create_server(): def test_mcp_no_proxy(): global server - server.webui_mcp_proxy = False + server.ui_mcp_proxy = False server.start() res = server.make_request("GET", "/cors-proxy") @@ -21,7 +21,7 @@ def test_mcp_no_proxy(): def test_mcp_proxy(): global server - server.webui_mcp_proxy = True + server.ui_mcp_proxy = True server.start() url = f"http://{server.server_host}:{server.server_port}/cors-proxy?url=http://example.com" @@ -32,7 +32,7 @@ def test_mcp_proxy(): def test_mcp_proxy_custom_port(): global server - server.webui_mcp_proxy = True + server.ui_mcp_proxy = True server.start() # try getting the server's models API via the proxy diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index 11c77ca7aa..94165e520e 100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -256,15 +256,45 @@ def test_router_reload_models(): os.remove(preset_path) +def test_router_remote_preset(): + global server + server.model_hf_repo = "ggml-org/test-preset-ci" + server.model_hf_file = None + server.offline = False + server.start() + + # Should see preset models in GET /models + res = server.make_request("GET", "/models") + assert res.status_code == 200 + ids = {item["id"] for item in res.body.get("data", [])} + assert "tinygemma3-preset" in ids + assert "stories260K-test" in ids + + # Should be able to load a preset model + model_id = "tinygemma3-preset" + _load_model_and_wait(model_id) + + MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16" -MODEL_DOWNLOAD_TIMEOUT = 300 +MODEL_DOWNLOAD_TIMEOUT = 30 -def _listen_sse(server: ServerProcess, collected: list, stop: threading.Event): - """Collect /models/sse events into `collected` until `stop` is set.""" +def _listen_sse( + server: ServerProcess, collected: list, stop: threading.Event, ready: threading.Event | None = None +): + """Collect /models/sse events into `collected` until `stop` is set. + + When `ready` is provided, it is set once the streaming response is open, + i.e. the server has accepted the connection and registered us as a + subscriber. Callers that trigger one-shot events (e.g. download_finished) + must wait on `ready` before acting, otherwise the event can be broadcast + before this client is subscribed and be lost. + """ url = f"http://{server.server_host}:{server.server_port}/models/sse" try: with requests.get(url, stream=True, timeout=MODEL_DOWNLOAD_TIMEOUT) as resp: + if ready is not None: + ready.set() for line_bytes in resp.iter_lines(): if stop.is_set(): break @@ -294,11 +324,17 @@ def test_router_download_model(): sse_events: list = [] stop = threading.Event() + sse_ready = threading.Event() sse_thread = threading.Thread( - target=_listen_sse, args=(server, sse_events, stop), daemon=True + target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True ) sse_thread.start() + # wait for the SSE client to be subscribed before triggering the download, + # otherwise the one-shot download_finished event can be broadcast before + # this client is registered and be lost + assert sse_ready.wait(10), "SSE client failed to connect" + # Trigger the download res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) assert res.status_code == 200 @@ -328,13 +364,17 @@ def test_router_delete_model(): # Ensure the model exists (download it if needed) if MODEL_DOWNLOAD_ID not in _get_model_ids(is_reload=False): - res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) - assert res.status_code == 200 sse_events: list = [] stop = threading.Event() + sse_ready = threading.Event() threading.Thread( - target=_listen_sse, args=(server, sse_events, stop), daemon=True + target=_listen_sse, args=(server, sse_events, stop, sse_ready), daemon=True ).start() + # subscribe before triggering the download so the one-shot + # download_finished event is not lost (see test_router_download_model) + assert sse_ready.wait(10), "SSE client failed to connect" + res = server.make_request("POST", "/models", data={"model": MODEL_DOWNLOAD_ID}) + assert res.status_code == 200 finished = _wait_for_sse_event( sse_events, "download_finished", MODEL_DOWNLOAD_ID, MODEL_DOWNLOAD_TIMEOUT ) diff --git a/tools/server/tests/unit/test_security.py b/tools/server/tests/unit/test_security.py index 02d0b1afbc..a0c3e214ae 100644 --- a/tools/server/tests/unit/test_security.py +++ b/tools/server/tests/unit/test_security.py @@ -1,6 +1,8 @@ import pytest from openai import OpenAI from utils import * +import threading +from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer server = ServerPreset.tinyllama2() @@ -105,6 +107,49 @@ def test_cors_options(origin: str, cors_header: str, cors_header_value: str): assert res.headers[cors_header] == cors_header_value +def test_cors_proxy_only_forwards_explicit_proxy_headers(): + class CaptureHeadersHandler(BaseHTTPRequestHandler): + def do_GET(self): + self.server.captured_headers = dict(self.headers) + self.send_response(200) + self.end_headers() + self.wfile.write(b"ok") + + def log_message(self, format, *args): + pass + + target = ThreadingHTTPServer(("127.0.0.1", 0), CaptureHeadersHandler) + target.captured_headers = {} + target_thread = threading.Thread(target=target.serve_forever, daemon=True) + target_thread.start() + + try: + server = ServerPreset.tinyllama2() + server.api_key = TEST_API_KEY + server.ui_mcp_proxy = True + server.start() + + res = server.make_request("GET", f"/cors-proxy?url=http://127.0.0.1:{target.server_port}/capture", headers={ + "Authorization": f"Bearer {TEST_API_KEY}", + "Proxy-Authorization": "Basic secret", + "X-Api-Key": TEST_API_KEY, + "Cookie": "session=secret", + "x-llama-server-proxy-header-accept": "application/json", + "x-llama-server-proxy-header-authorization": "Bearer explicit", + }) + + assert res.status_code == 200 + captured = {key.lower(): value for key, value in target.captured_headers.items()} + assert captured["accept"] == "application/json" + assert captured["authorization"] == "Bearer explicit" + assert "proxy-authorization" not in captured + assert "x-api-key" not in captured + assert "cookie" not in captured + finally: + target.shutdown() + target.server_close() + + @pytest.mark.parametrize( "media_path, image_url, success", [ diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index c50c9a0f5a..63a959449e 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -94,7 +94,7 @@ class ServerProcess: enable_ctx_shift: int | None = False spec_draft_n_min: int | None = None spec_draft_n_max: int | None = None - no_webui: bool | None = None + no_ui: bool | None = None jinja: bool | None = None reasoning_format: Literal['deepseek', 'none', 'nothink'] | None = None reasoning: Literal['on', 'off', 'auto'] | None = None @@ -107,7 +107,7 @@ class ServerProcess: cache_ram: int | None = None no_cache_idle_slots: bool = False log_path: str | None = None - webui_mcp_proxy: bool = False + ui_mcp_proxy: bool = False backend_sampling: bool = False gcp_compat: bool = False @@ -225,8 +225,8 @@ class ServerProcess: server_args.extend(["--spec-draft-n-max", self.spec_draft_n_max]) if self.spec_draft_n_min: server_args.extend(["--spec-draft-n-min", self.spec_draft_n_min]) - if self.no_webui: - server_args.append("--no-webui") + if self.no_ui: + server_args.append("--no-ui") if self.no_models_autoload: server_args.append("--no-models-autoload") if self.jinja: @@ -251,8 +251,8 @@ class ServerProcess: server_args.extend(["--cache-ram", self.cache_ram]) if self.no_cache_idle_slots: server_args.append("--no-cache-idle-slots") - if self.webui_mcp_proxy: - server_args.append("--webui-mcp-proxy") + if self.ui_mcp_proxy: + server_args.append("--ui-mcp-proxy") if self.backend_sampling: server_args.append("--backend_sampling") if self.gcp_compat: diff --git a/tools/ui/.gitignore b/tools/ui/.gitignore index ddcfe2e60f..0bb8c9b3c2 100644 --- a/tools/ui/.gitignore +++ b/tools/ui/.gitignore @@ -28,10 +28,9 @@ vite.config.ts.timestamp-* # PWA Artifacts apple-splash-*.png apple-touch-icon-*.png -favicon.ico -favicon-dark.ico maskable-icon-*.png pwa-*.png +static/favicon* # Storybook *storybook.log diff --git a/tools/ui/package-lock.json b/tools/ui/package-lock.json index 9d0cdfea6c..9dce3a0c9d 100644 --- a/tools/ui/package-lock.json +++ b/tools/ui/package-lock.json @@ -35,7 +35,7 @@ "bits-ui": "2.18.1", "clsx": "2.1.1", "dexie": "4.4.3", - "dompurify": "3.4.5", + "dompurify": "3.4.11", "eslint": "9.39.4", "eslint-config-prettier": "10.1.8", "eslint-plugin-storybook": "10.4.2", @@ -8653,9 +8653,9 @@ "peer": true }, "node_modules/dompurify": { - "version": "3.4.5", - "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.5.tgz", - "integrity": "sha512-OrwIBKsdNSVEeubdJ1HBv/wNENRM9ytAVCv7YXt//A3vPdVMNuACRqK9mXCGCBW2ln7BT/A4X0jXHo2Gu89miA==", + "version": "3.4.11", + "resolved": "https://registry.npmjs.org/dompurify/-/dompurify-3.4.11.tgz", + "integrity": "sha512-zhlUV12GsaRzMsf9q5M254YhA4+VuF0fG+QFqu6aYpoGlKtz+w8//jBcGVYBgQkR5GHjUomejY84AV+/uPbWdw==", "dev": true, "license": "(MPL-2.0 OR Apache-2.0)", "optionalDependencies": { @@ -10226,9 +10226,9 @@ } }, "node_modules/hono": { - "version": "4.12.23", - "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.23.tgz", - "integrity": "sha512-eIaZ9qDgu7XV0pxOCrg7/WhnQ6Ivm22UcxhXx/A3dcbqbbYgBEkc6e/J/s7j2tS96zoB0S9VBdLwQNCWwUo4LA==", + "version": "4.12.26", + "resolved": "https://registry.npmjs.org/hono/-/hono-4.12.26.tgz", + "integrity": "sha512-uyZtpnYxM9CmQ7QsQknM4zN8EftNqhON1qYeIKM0Se67CCEe2c44xyGURwB0axX2fBDu1dqHrHAc1hmNT8ITkw==", "dev": true, "license": "MIT", "engines": { diff --git a/tools/ui/package.json b/tools/ui/package.json index 4803922889..bcb4165d10 100644 --- a/tools/ui/package.json +++ b/tools/ui/package.json @@ -54,7 +54,7 @@ "bits-ui": "2.18.1", "clsx": "2.1.1", "dexie": "4.4.3", - "dompurify": "3.4.5", + "dompurify": "3.4.11", "eslint": "9.39.4", "eslint-config-prettier": "10.1.8", "eslint-plugin-storybook": "10.4.2", diff --git a/tools/ui/pwa-assets-dark.config.ts b/tools/ui/pwa-assets-dark.config.ts index e9793bca9e..358c0ebc07 100644 --- a/tools/ui/pwa-assets-dark.config.ts +++ b/tools/ui/pwa-assets-dark.config.ts @@ -1,4 +1,10 @@ import { defineConfig } from '@vite-pwa/assets-generator/config'; +import { FAVICON_COLORS, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa'; +import { writeThemeFavicons } from './scripts/favicon-colorize'; + +writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, { + padding: PWA_ASSET_GENERATOR.FAVICON_PADDING +}); export default defineConfig({ headLinkOptions: { @@ -7,7 +13,8 @@ export default defineConfig({ preset: { transparent: { sizes: [], - favicons: [[48, 'favicon-dark.ico']] + favicons: [[48, 'favicon-dark.ico']], + padding: PWA_ASSET_GENERATOR.FAVICON_PADDING }, maskable: { sizes: [] diff --git a/tools/ui/pwa-assets.config.ts b/tools/ui/pwa-assets.config.ts index 54928eeb4a..b69884d94a 100644 --- a/tools/ui/pwa-assets.config.ts +++ b/tools/ui/pwa-assets.config.ts @@ -5,15 +5,32 @@ import { } from '@vite-pwa/assets-generator/config'; import { readFileSync } from 'node:fs'; import { resolve } from 'node:path'; -import { THEME_COLORS, PWA_GENERATOR_DEVICES, PWA_ASSET_GENERATOR } from './src/lib/constants/pwa'; +import { + THEME_COLORS, + PWA_GENERATOR_DEVICES, + PWA_ASSET_GENERATOR, + FAVICON_COLORS +} from './src/lib/constants/pwa'; import { SplashOrientation } from './src/lib/enums/splash.enums'; +import { writeThemeFavicons } from './scripts/favicon-colorize'; + +writeThemeFavicons(FAVICON_COLORS.LIGHT, FAVICON_COLORS.DARK, { + padding: PWA_ASSET_GENERATOR.FAVICON_PADDING +}); export default defineConfig({ headLinkOptions: { preset: PWA_ASSET_GENERATOR.LINK_PRESET }, preset: combinePresetAndAppleSplashScreens( - minimal2023Preset, + { + ...minimal2023Preset, + // tiny margin so favicon.ico / pwa-*.png breathe inside the canvas + transparent: { + ...minimal2023Preset.transparent, + padding: PWA_ASSET_GENERATOR.FAVICON_PADDING + } + }, { padding: PWA_ASSET_GENERATOR.SPLASH_PADDING, resizeOptions: { diff --git a/tools/ui/scripts/favicon-colorize.ts b/tools/ui/scripts/favicon-colorize.ts new file mode 100644 index 0000000000..e1872b7774 --- /dev/null +++ b/tools/ui/scripts/favicon-colorize.ts @@ -0,0 +1,107 @@ +import { mkdirSync, readFileSync, writeFileSync } from 'node:fs'; +import { dirname, resolve } from 'node:path'; +import { fileURLToPath } from 'node:url'; + +const HERE = dirname(fileURLToPath(import.meta.url)); +const PROJECT_ROOT = resolve(HERE, '..'); + +const DEFAULT_LOGO = resolve(PROJECT_ROOT, 'src/lib/assets/logo.svg'); +const DEFAULT_OUT_DIR = resolve(PROJECT_ROOT, 'static'); +const DEFAULT_OUT_LIGHT = resolve(DEFAULT_OUT_DIR, 'favicon.svg'); +const DEFAULT_OUT_DARK = resolve(DEFAULT_OUT_DIR, 'favicon-dark.svg'); + +const CURRENT_COLOR = 'currentColor'; + +export interface ColorizedFavicon { + light: string; + dark: string; +} + +export interface WriteThemeFaviconsOptions { + sourcePath?: string; + lightOutPath?: string; + darkOutPath?: string; + /** + * Fraction of the icon (0..1) to leave as an even margin on each side. + * Applied by wrapping the inner content in a `` so the + * source `src/lib/assets/logo.svg` is not modified. Pass 0 to disable. + */ + padding?: number; +} + +/** + * Replace every `currentColor` occurrence in the SVG with the given color. + * Pure: no filesystem access, so it is straightforward to unit-test. + */ +export function colorizeFaviconSvg( + svg: string, + lightColor: string, + darkColor: string +): ColorizedFavicon { + return { + light: svg.replaceAll(CURRENT_COLOR, lightColor), + dark: svg.replaceAll(CURRENT_COLOR, darkColor) + }; +} + +/** + * Shrink the inner SVG content uniformly and re-center it so `padding` (a + * 0..1 fraction) is reserved as equal margin on each side. Returns the input + * unchanged for non-positive padding, missing/invalid `viewBox`, or unexpected + * markup so the caller always gets a renderable SVG. + */ +export function padFaviconSvg(svg: string, padding: number): string { + if (!(padding > 0) || padding >= 1) return svg; + + const viewBoxMatch = svg.match(/viewBox\s*=\s*["']([^"']+)["']/i); + if (!viewBoxMatch) return svg; + + const parts = viewBoxMatch[1] + .trim() + .split(/[\s,]+/) + .map(Number); + if (parts.length !== 4 || parts.some((n) => !Number.isFinite(n))) return svg; + + const [, , width, height] = parts; + if (width <= 0 || height <= 0) return svg; + + const scale = 1 - padding; + const translateX = (padding * width) / 2; + const translateY = (padding * height) / 2; + + const openTagStart = svg.search(/', openTagStart); + if (openTagEnd === -1) return svg; + const closeStart = svg.lastIndexOf('`; + return `${openTag}${group}${inner}${closeTag}`; +} + +/** + * Read `src/lib/assets/logo.svg`, colorize it for both themes, and write + * the results to the static directory so the PWA asset generator can consume + * them. Paths can be overridden for tests. + */ +export function writeThemeFavicons( + lightColor: string, + darkColor: string, + { + sourcePath = DEFAULT_LOGO, + lightOutPath = DEFAULT_OUT_LIGHT, + darkOutPath = DEFAULT_OUT_DARK, + padding = 0 + }: WriteThemeFaviconsOptions = {} +): void { + const source = readFileSync(sourcePath, 'utf-8'); + const { light, dark } = colorizeFaviconSvg(source, lightColor, darkColor); + mkdirSync(dirname(lightOutPath), { recursive: true }); + writeFileSync(lightOutPath, padFaviconSvg(light, padding)); + writeFileSync(darkOutPath, padFaviconSvg(dark, padding)); +} diff --git a/tools/ui/src/app.css b/tools/ui/src/app.css index 9254a96df7..8c4056477d 100644 --- a/tools/ui/src/app.css +++ b/tools/ui/src/app.css @@ -48,6 +48,7 @@ --chat-form-area-height: 8rem; --chat-form-area-offset: 2rem; + --chat-form-padding-top: 6rem; --max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem))); } @@ -55,6 +56,7 @@ :root { --chat-form-area-height: 24rem; --chat-form-area-offset: 12rem; + --chat-form-padding-top: 6rem; } } @@ -141,7 +143,6 @@ @apply bg-background text-foreground; scrollbar-width: thin; scrollbar-gutter: stable; - overflow: hidden; /* Added due to Mermaid rendering somehow causing the double scrollbar */ } /* Global scrollbar styling - visible only on hover */ @@ -193,3 +194,7 @@ scrollbar-width: none; } } + +.mermaidTooltip { + display: none !important; +} diff --git a/tools/ui/src/app.d.ts b/tools/ui/src/app.d.ts index a7583eec59..5264e5cc4d 100644 --- a/tools/ui/src/app.d.ts +++ b/tools/ui/src/app.d.ts @@ -19,6 +19,10 @@ import type { ApiErrorResponse, ApiLlamaCppServerProps, ApiModelDataEntry, + ApiModelLoadStage, + ApiModelsSseProgress, + ApiModelsSseData, + ApiModelsSseEvent, ApiModelListResponse, ApiProcessingState, ApiRouterModelMeta, @@ -52,6 +56,7 @@ import type { // Model types ModelModalities, ModelOption, + ModelLoadProgress, // Settings types SettingsChatServiceOptions, SettingsConfigValue, @@ -83,6 +88,10 @@ declare global { ApiErrorResponse, ApiLlamaCppServerProps, ApiModelDataEntry, + ApiModelLoadStage, + ApiModelsSseProgress, + ApiModelsSseData, + ApiModelsSseEvent, ApiModelListResponse, ApiProcessingState, ApiRouterModelMeta, @@ -120,6 +129,7 @@ declare global { // Model types ModelModalities, ModelOption, + ModelLoadProgress, // Settings types SettingsChatServiceOptions, SettingsConfigValue, diff --git a/tools/ui/src/lib/actions/fade-in-view.svelte.ts b/tools/ui/src/lib/actions/fade-in-view.svelte.ts index d930448050..9a5918131a 100644 --- a/tools/ui/src/lib/actions/fade-in-view.svelte.ts +++ b/tools/ui/src/lib/actions/fade-in-view.svelte.ts @@ -10,9 +10,9 @@ import { isElementInViewport } from '$lib/utils/viewport'; */ export function fadeInView( node: HTMLElement, - options: { duration?: number; y?: number; skipIfVisible?: boolean } = {} + options: { duration?: number; y?: number; delay?: number; skipIfVisible?: boolean } = {} ) { - const { duration = 300, y = 0, skipIfVisible = false } = options; + const { duration = 300, y = 0, delay = 0, skipIfVisible = false } = options; if (skipIfVisible && isElementInViewport(node)) { return; @@ -27,10 +27,12 @@ export function fadeInView( (entries) => { for (const entry of entries) { if (entry.isIntersecting) { - requestAnimationFrame(() => { - node.style.opacity = '1'; - node.style.transform = 'translateY(0)'; - }); + setTimeout(() => { + requestAnimationFrame(() => { + node.style.opacity = '1'; + node.style.transform = 'translateY(0)'; + }); + }, delay); observer.disconnect(); } } diff --git a/tools/ui/src/lib/assets/logo.svg b/tools/ui/src/lib/assets/logo.svg new file mode 100644 index 0000000000..05424790af --- /dev/null +++ b/tools/ui/src/lib/assets/logo.svg @@ -0,0 +1,7 @@ + + + + + + + diff --git a/tools/ui/src/lib/components/app/actions/ActionIcon.svelte b/tools/ui/src/lib/components/app/actions/ActionIcon.svelte index f156df6699..8a86557bb9 100644 --- a/tools/ui/src/lib/components/app/actions/ActionIcon.svelte +++ b/tools/ui/src/lib/components/app/actions/ActionIcon.svelte @@ -8,12 +8,13 @@ ariaLabel?: string; class?: string; disabled?: boolean; + href?: string; icon: Component; iconSize?: string; - onclick: (e?: MouseEvent) => void; + onclick?: (e?: MouseEvent) => void; size?: ButtonSize; stopPropagationOnClick?: boolean; - tooltip: string; + tooltip?: string; variant?: ButtonVariant; tooltipSide?: TooltipSide; } @@ -22,6 +23,7 @@ icon, tooltip, variant = 'ghost', + href = '', size = 'sm', class: className = '', disabled = false, @@ -31,34 +33,49 @@ onclick, ariaLabel }: Props = $props(); + + let innerWidth = $state(0); + const showTooltip = $derived(!!tooltip && innerWidth > 768); - - - - {#snippet child({ props })} - - {/snippet} - + onclick?.(e); + }} + class="h-6 w-6 p-0 {className} flex hover:bg-transparent data-[state=open]:bg-transparent!" + aria-label={ariaLabel || tooltip} + > + {#if icon} + {@const IconComponent = icon} - -

{tooltip}

-
-
+ + {/if} + +{/snippet} + +{#if showTooltip} + + + + {#snippet child({ props })} + {@render button(props)} + {/snippet} + + + +

{tooltip}

+
+
+{:else} + {@render button({ href })} +{/if} + + diff --git a/tools/ui/src/lib/components/app/chat/ChatForm/ChatForm.svelte b/tools/ui/src/lib/components/app/chat/ChatForm/ChatForm.svelte index ed26f9ea58..9b2077b8dc 100644 --- a/tools/ui/src/lib/components/app/chat/ChatForm/ChatForm.svelte +++ b/tools/ui/src/lib/components/app/chat/ChatForm/ChatForm.svelte @@ -494,7 +494,7 @@ />
{#if hasMcpPromptsSupport} - diff --git a/tools/ui/src/lib/components/app/navigation/DesktopIconStrip.svelte b/tools/ui/src/lib/components/app/navigation/DesktopIconStrip.svelte deleted file mode 100644 index e92b9528a6..0000000000 --- a/tools/ui/src/lib/components/app/navigation/DesktopIconStrip.svelte +++ /dev/null @@ -1,84 +0,0 @@ - - - - - - diff --git a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigation.svelte b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigation.svelte index 1fa7722f24..fe503f53ba 100644 --- a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigation.svelte +++ b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigation.svelte @@ -1,40 +1,76 @@ -
- - -
- -

- {APP_NAME} -

-
+ - +{#if innerWidth > 768 || (!page.url.hash.includes(ROUTES.SETTINGS) && !page.url.hash.includes(ROUTES.MCP_SERVERS) && !page.url.hash.includes(ROUTES.SEARCH))} + +{/if} + + diff --git a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationActions.svelte b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationActions.svelte index f0d63970ee..f118a68dfc 100644 --- a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationActions.svelte +++ b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationActions.svelte @@ -1,39 +1,86 @@ @@ -41,56 +88,109 @@ {/snippet} -
- {#if isSearchModeActive} +{#if isSearchModeActive} +
e.key === 'Escape' && handleSearchModeDeactivate()} placeholder="Search conversations..." - {isCancelAlwaysVisible} /> - {:else} - {#each SIDEBAR_ACTIONS_ITEMS as item (item.route)} - {#if !item.route} - - {:else} - + {#if item.keys} + + {/if} + +
{/if} {/each} - {/if} -
+
+{:else} + +{/if} diff --git a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationConversationList.svelte b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationConversationList.svelte new file mode 100644 index 0000000000..488e96bbcf --- /dev/null +++ b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationConversationList.svelte @@ -0,0 +1,135 @@ + + +{#if isSearchModeActive} + +{:else} + {#if pinnedConversations.length > 0} +
+
+ + + Pinned +
+
+ +
    + {#each pinnedConversations as { conversation, depth } (conversation.id)} +
  • + +
  • + {/each} +
+ {/if} + +
+ {#if filteredConversations.length > 0} +
+ Recent conversations +
+ {/if} + +
+
    + {#each unpinnedConversations as { conversation, depth } (conversation.id)} +
  • + +
  • + {/each} + + {#if unpinnedConversations.length === 0} +
  • +

    + {recentEmptyMessage} +

    +
  • + {/if} +
+
+
+{/if} diff --git a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearch.svelte b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearch.svelte index afc9847028..491e7c3479 100644 --- a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearch.svelte +++ b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearch.svelte @@ -16,4 +16,6 @@ }: Props = $props(); - +
+ +
diff --git a/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearchResults.svelte b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearchResults.svelte new file mode 100644 index 0000000000..92d8fd0bda --- /dev/null +++ b/tools/ui/src/lib/components/app/navigation/SidebarNavigation/SidebarNavigationSearchResults.svelte @@ -0,0 +1,76 @@ + + +
+ {#if showHeader} +
+ Search results +
+ {/if} + +
+
    + {#each tree as { conversation, depth } (conversation.id)} +
  • + +
  • + {/each} + + {#if tree.length === 0} +
  • +

    + {emptyMessage} +

    +
  • + {/if} +
+
+
diff --git a/tools/ui/src/lib/components/app/navigation/index.ts b/tools/ui/src/lib/components/app/navigation/index.ts index d4ca914594..e07dde6bc9 100644 --- a/tools/ui/src/lib/components/app/navigation/index.ts +++ b/tools/ui/src/lib/components/app/navigation/index.ts @@ -63,15 +63,6 @@ export { default as DropdownMenuSearchable } from './DropdownMenuSearchable.svel * ``` */ export { default as DropdownMenuActions } from './DropdownMenuActions.svelte'; - -/** - * **DesktopIconStrip** - Fixed icon strip for desktop sidebar - * - * Vertical icon strip shown on desktop when the sidebar is collapsed. - * Contains navigation shortcuts for new chat, search, MCP, import/export, and settings. - */ -export { default as DesktopIconStrip } from './DesktopIconStrip.svelte'; - /** * **SidebarNavigation** - Sidebar with actions menu and conversation list * @@ -115,13 +106,6 @@ export { default as DesktopIconStrip } from './DesktopIconStrip.svelte'; */ export { default as SidebarNavigation } from './SidebarNavigation/SidebarNavigation.svelte'; -/** - * Action buttons for sidebar header. Contains new chat button, settings button, - * and delete all conversations button. Manages dialog states for settings and - * delete confirmation. - */ -export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte'; - /** * Single conversation item in sidebar. Displays conversation title (truncated), * last message preview, and timestamp. Shows context menu on right-click with @@ -130,6 +114,58 @@ export { default as SidebarNavigationActions } from './SidebarNavigation/Sidebar */ export { default as SidebarNavigationConversationItem } from './SidebarNavigation/SidebarNavigationConversationItem.svelte'; +/** + * **SidebarNavigationConversationList** - Grouped conversation list + * + * Pure-presentational list of conversations. Splits items into a Pinned + * section (when not in search mode) and a Recent Conversations / Search + * Results section with the unpinned items. Item selection, edit, delete, + * and stop-generation are delegated to the caller via callbacks. + * + * @example + * ```svelte + * + * ``` + */ +export { default as SidebarNavigationConversationList } from './SidebarNavigation/SidebarNavigationConversationList.svelte'; +export { default as SidebarNavigationActions } from './SidebarNavigation/SidebarNavigationActions.svelte'; + +/** + * **SidebarNavigationSearchResults** - Filtered conversation list for search. + * + * Pure-presentational rendering of the search-mode subtree: "Search results" + * header, the matching items rendered through {@link SidebarNavigationConversationItem}, + * and contextual empty-state messages. Used both inline inside + * {@link SidebarNavigationConversationList} (when search mode is active in the + * sidebar) and as the body of the mobile `/search` route. + * + * The caller is expected to provide an already-filtered list via + * `filteredConversations` and a `searchQuery` for the empty-state messages. + * + * @example + * ```svelte + * + * ``` + */ +export { default as SidebarNavigationSearchResults } from './SidebarNavigation/SidebarNavigationSearchResults.svelte'; + /** * Search input for filtering conversations in sidebar. Filters conversation * list by title as user types. Shows clear button when query is not empty. diff --git a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte index 62e73d8579..41105baa5e 100644 --- a/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte +++ b/tools/ui/src/lib/components/app/settings/SettingsChat/SettingsChat.svelte @@ -126,10 +126,7 @@ }); -
+
-