diff --git a/CODEOWNERS b/CODEOWNERS index 4b9d901771..46fd518b7e 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -10,7 +10,7 @@ # ggml-org/ggml-rpc : rgerganov # ggml-org/ggml-sycl : arthw # ggml-org/ggml-vulkan : 0cc4m, jeffbolznv -# ggml-org/ggml-webgpu : reeselevine +# ggml-org/ggml-webgpu : reeselevine, yomaytk # ggml-org/ggml-zdnn : taronaeo # ggml-org/llama-common : ggerganov, aldehir, angt, danbev, ngxson, pwilkin # ggml-org/llama-mtmd : ngxson diff --git a/common/arg.cpp b/common/arg.cpp index 5297d90753..276dbec8ba 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -301,6 +301,8 @@ static handle_model_result common_params_handle_model(struct common_params_model const common_download_opts & opts) { handle_model_result result; + // TODO @ngxson : refactor this into a new common_model_download_context + if (!model.docker_repo.empty()) { model.path = common_docker_resolve_model(model.docker_repo); } else if (!model.hf_repo.empty()) { @@ -396,7 +398,7 @@ static bool parse_bool_value(const std::string & value) { // CLI argument parsing functions // -bool common_params_handle_models(common_params & params, llama_example curr_ex, common_download_callback * callback) { +bool common_params_handle_models(common_params & params, llama_example curr_ex, const common_params_handle_models_params & handle_params) { const bool spec_type_draft_mtp = std::find(params.speculative.types.begin(), params.speculative.types.end(), COMMON_SPECULATIVE_TYPE_DRAFT_MTP) != params.speculative.types.end(); @@ -407,9 +409,10 @@ bool common_params_handle_models(common_params & params, llama_example curr_ex, opts.skip_download = params.skip_download; opts.download_mtp = spec_type_draft_mtp; opts.download_mmproj = !params.no_mmproj && params.mmproj.path.empty() && params.mmproj.url.empty(); + opts.preset_only = handle_params.preset_only; - if (callback) { - opts.callback = callback; + if (handle_params.callback) { + opts.callback = handle_params.callback; } // sub-models (draft, mmproj, vocoder) are explicitly specified by the user, @@ -596,7 +599,7 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context if (!skip_model_download) { // handle model and download - common_params_handle_models(params, ctx_arg.ex); + common_params_handle_models(params, ctx_arg.ex, {}); // model is required (except for server) // TODO @ngxson : maybe show a list of available models in CLI in this case diff --git a/common/arg.h b/common/arg.h index c061fc60f7..fdfc04bc7a 100644 --- a/common/arg.h +++ b/common/arg.h @@ -130,6 +130,11 @@ bool common_params_to_map(int argc, char ** argv, llama_example ex, std::map & args); +struct common_params_handle_models_params { + common_download_callback * callback = nullptr; + bool preset_only = false; // if true, only check & download remote preset (for router mode) +}; + // populate model paths (main model, mmproj, etc) from -hf if necessary // return true if the model is ready to use // throw an exception if there is an error that prevents the model from being used (e.g. network error, model not found, etc) @@ -137,7 +142,7 @@ void common_params_add_preset_options(std::vector & args); bool common_params_handle_models( common_params & params, llama_example curr_ex, - common_download_callback * callback = nullptr); + const common_params_handle_models_params & handle_params); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); diff --git a/common/chat.cpp b/common/chat.cpp index ded8440e66..cee6ad650a 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -90,41 +90,93 @@ std::string common_chat_msg::render_content(const std::string & delimiter) const return text; } -std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims) { - if (delims.empty() || prompt.empty()) { - return {}; +common_chat_role common_chat_role_from_string(const std::string & role) { + if (role == "system") { return COMMON_CHAT_ROLE_SYSTEM; } + if (role == "assistant") { return COMMON_CHAT_ROLE_ASSISTANT; } + if (role == "user") { return COMMON_CHAT_ROLE_USER; } + if (role == "tool") { return COMMON_CHAT_ROLE_TOOL; } + return COMMON_CHAT_ROLE_UNKNOWN; +} + +const char * common_chat_role_to_string(common_chat_role role) { + switch (role) { + case COMMON_CHAT_ROLE_SYSTEM: return "system"; + case COMMON_CHAT_ROLE_ASSISTANT: return "assistant"; + case COMMON_CHAT_ROLE_USER: return "user"; + case COMMON_CHAT_ROLE_TOOL: return "tool"; + case COMMON_CHAT_ROLE_UNKNOWN: return ""; + } + return ""; +} + +json common_chat_msg_delimiters::to_json() const { + json result = json::array(); + for (const auto & d : delimiters) { + result.push_back({ + { "role", common_chat_role_to_string(d.role) }, + { "delimiter", d.delimiter }, + }); + } + return result; +} + +common_chat_msg_delimiters common_chat_msg_delimiters_parse(const json & delimiters) { + common_chat_msg_delimiters result; + + if (!delimiters.is_array()) { + return result; } - auto parser = build_peg_parser([&](common_peg_parser_builder & p) { - std::vector all_delims; - std::vector tagged_messages; - - all_delims.reserve(delims.size()); - tagged_messages.reserve(delims.size()); - for (const auto & d : delims) { - all_delims.push_back(d.delimiter); + result.delimiters.reserve(delimiters.size()); + for (const auto & d : delimiters) { + if (!d.is_object()) { + continue; } - - auto any_delim = p.until_one_of(all_delims); - for (const auto & d : delims) { - tagged_messages.push_back(p.tag(d.role, p.literal(d.delimiter) + any_delim)); - } - - return any_delim + p.zero_or_more(p.choice(tagged_messages)) + p.end(); - }); - - common_peg_parse_context ctx(prompt); - const auto result = parser.parse(ctx); - if (!result.success()) { - return {}; + result.delimiters.push_back({ + common_chat_role_from_string(d.value("role", std::string())), + d.value("delimiter", std::string()), + }); } - std::vector spans; - ctx.ast.visit(result, [&](const common_peg_ast_node & node) { - if (!node.tag.empty()) { - spans.push_back({ node.tag, node.start, node.end - node.start }); + return result; +} + +void common_chat_msg_delimiters::tokenize(const llama_vocab * vocab) { + for (auto & d : delimiters) { + d.tokens = common_tokenize(vocab, d.delimiter, false, true); + } +} + +common_chat_msg_spans common_chat_msg_delimiters::split(const llama_tokens & tokens, const std::map & skips) const { + std::vector> matches; + + auto skip = skips.begin(); + for (size_t i = 0; i < tokens.size();) { + if (skip != skips.end() && i == skip->first) { + i += skip->second; + ++skip; + continue; } - }); + for (const auto & d : delimiters) { + if (i + d.tokens.size() > tokens.size()) { + continue; + } + if (std::equal(d.tokens.begin(), d.tokens.end(), tokens.begin() + i)) { + matches.emplace_back(d.role, i); + break; + } + } + i++; + } + + matches.emplace_back(COMMON_CHAT_ROLE_UNKNOWN, tokens.size()); + + common_chat_msg_spans spans; + for (size_t i = 0; i + 1 < matches.size(); i++) { + const auto & curr = matches[i]; + const auto & next = matches[i + 1]; + spans.add(curr.first, curr.second, next.second - curr.second); + } return spans; } @@ -1081,13 +1133,13 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp data.prompt = prompt; data.generation_prompt = common_chat_template_generation_prompt_impl(tmpl, inputs, /* messages_override= */ adjusted_messages); - data.message_spans = common_chat_split_by_role(prompt, { - { "assistant", "<|start|>assistant" }, - { "user", "<|start|>user" }, - { "system", "<|start|>developer" }, - { "system", "<|start|>system" }, - { "tool", "<|start|>functions" }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_ASSISTANT, "<|start|>assistant" }, + { COMMON_CHAT_ROLE_USER, "<|start|>user" }, + { COMMON_CHAT_ROLE_SYSTEM, "<|start|>developer" }, + { COMMON_CHAT_ROLE_SYSTEM, "<|start|>system" }, + { COMMON_CHAT_ROLE_TOOL, "<|start|>functions" }, + }; data.format = COMMON_CHAT_FORMAT_PEG_NATIVE; data.supports_thinking = true; @@ -1228,10 +1280,10 @@ static common_chat_params common_chat_params_init_gemma4(const common_chat_templ data.prompt += data.generation_prompt; } - data.message_spans = common_chat_split_by_role(data.prompt, { - { "user", "<|turn>user\n" }, - { "assistant", "<|turn>model\n" }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_USER, "<|turn>user" }, + { COMMON_CHAT_ROLE_ASSISTANT, "<|turn>model" }, + }; data.format = COMMON_CHAT_FORMAT_PEG_GEMMA4; data.supports_thinking = true; @@ -2030,15 +2082,15 @@ static common_chat_params common_chat_params_init_cohere2moe(const common_chat_t RESULT_START, RESULT_END, }; - // Split the rendered prompt into per-role message spans. Tool results are rendered with the + // Declare per-role message delimiters. Tool results are rendered with the // system token followed by <|START_TOOL_RESULT|>, so the "tool" delimiter must be listed before // the plain "system" one (it is a strict superset, and the role split tries delimiters in order). - data.message_spans = common_chat_split_by_role(data.prompt, { - { "assistant", GEN_PREFIX }, - { "user", TURN_START + USER }, - { "tool", TURN_START + SYSTEM + RESULT_START }, - { "system", TURN_START + SYSTEM }, - }); + data.message_delimiters = { + { COMMON_CHAT_ROLE_ASSISTANT, GEN_PREFIX }, + { COMMON_CHAT_ROLE_USER, TURN_START + USER }, + { COMMON_CHAT_ROLE_TOOL, TURN_START + SYSTEM + RESULT_START }, + { COMMON_CHAT_ROLE_SYSTEM, TURN_START + SYSTEM }, + }; auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; @@ -2526,17 +2578,15 @@ static common_chat_params common_chat_templates_apply_jinja(const struct common_ autoparser.analyze_template(tmpl); auto auto_params = autoparser::peg_generator::generate_parser(tmpl, params, autoparser); - std::vector delimiters; + common_chat_msg_delimiters delimiters; if (!autoparser.assistant_start.empty()) { - delimiters.push_back({ "assistant", autoparser.assistant_start }); + delimiters.add(COMMON_CHAT_ROLE_ASSISTANT, autoparser.assistant_start); } if (!autoparser.user_start.empty()) { - delimiters.push_back({ "user", autoparser.user_start }); + delimiters.add(COMMON_CHAT_ROLE_USER, autoparser.user_start); } - if (!delimiters.empty()) { - auto_params.message_spans = common_chat_split_by_role(auto_params.prompt, delimiters); - } + auto_params.message_delimiters = std::move(delimiters); auto_params.supports_thinking = autoparser.reasoning.mode != autoparser::reasoning_mode::NONE; if (auto_params.supports_thinking) { diff --git a/common/chat.h b/common/chat.h index 5659cd42a0..7898f1623f 100644 --- a/common/chat.h +++ b/common/chat.h @@ -143,15 +143,75 @@ struct common_chat_msg_diff { } }; +enum common_chat_role { + COMMON_CHAT_ROLE_UNKNOWN, + COMMON_CHAT_ROLE_SYSTEM, + COMMON_CHAT_ROLE_ASSISTANT, + COMMON_CHAT_ROLE_USER, + COMMON_CHAT_ROLE_TOOL +}; + +common_chat_role common_chat_role_from_string(const std::string & role); +const char * common_chat_role_to_string(common_chat_role role); + struct common_chat_msg_span { - std::string role; + common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN; std::size_t pos = 0; std::size_t len = 0; + + bool valid() const { + return role != COMMON_CHAT_ROLE_UNKNOWN; + } +}; + +struct common_chat_msg_spans { + std::vector spans; + + void add(common_chat_role role, size_t pos, size_t len) { + spans.push_back({ role, pos, len }); + } + + bool is_user_start(int32_t pos) const { + for (auto it = spans.begin(); it != spans.end(); ++it) { + if (it->role == COMMON_CHAT_ROLE_USER && pos == (int32_t) it->pos) { + return true; + } + } + return false; + } + + int32_t last_user_message_pos() const { + for (auto it = spans.rbegin(); it != spans.rend(); ++it) { + if (it->role == COMMON_CHAT_ROLE_USER) { + return (int32_t) it->pos; + } + } + return -1; + } }; struct common_chat_msg_delimiter { - std::string role; - std::string delimiter; + common_chat_role role = COMMON_CHAT_ROLE_UNKNOWN; + std::string delimiter; + llama_tokens tokens = {}; +}; + +struct common_chat_msg_delimiters { + std::vector delimiters; + + common_chat_msg_delimiters() = default; + common_chat_msg_delimiters(std::initializer_list delims) : delimiters(delims) {} + + void add(common_chat_role role, const std::string & delimiter) { + delimiters.push_back({ role, delimiter }); + } + + void tokenize(const llama_vocab * vocab); + + // split tokens into message spans. skips maps a start index to a length of a region to jump over without matching + common_chat_msg_spans split(const llama_tokens & tokens, const std::map & skips = {}) const; + + nlohmann::ordered_json to_json() const; }; struct common_chat_tool { @@ -219,7 +279,7 @@ struct common_chat_params { std::vector preserved_tokens; std::vector additional_stops; std::string parser; - std::vector message_spans; + common_chat_msg_delimiters message_delimiters; }; // per-message parsing syntax @@ -325,5 +385,4 @@ struct common_chat_prompt_preset { common_chat_prompt_preset common_chat_get_asr_prompt(const common_chat_templates * chat_templates); -std::vector common_chat_split_by_role(const std::string & prompt, const std::vector & delims); - +common_chat_msg_delimiters common_chat_msg_delimiters_parse(const nlohmann::ordered_json & delimiters); diff --git a/common/common.h b/common/common.h index 0b9dd49766..381c0306c3 100644 --- a/common/common.h +++ b/common/common.h @@ -609,7 +609,7 @@ struct common_params { bool cache_prompt = true; // whether to enable prompt caching bool cache_idle_slots = true; // save and clear idle slots upon starting a new task int32_t n_ctx_checkpoints = 32; // max number of context checkpoints per slot - int32_t checkpoint_min_step = 256; // minimum spacing between context checkpoints + int32_t checkpoint_min_step = 8192; // minimum spacing between context checkpoints int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; diff --git a/common/download.cpp b/common/download.cpp index f320462753..5b55c76a11 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -799,6 +799,7 @@ common_download_model_result common_download_model(const common_params_model & bool download_mmproj = opts.download_mmproj; bool download_mtp = opts.download_mtp; + bool preset_only = opts.preset_only; bool is_hf = !model.hf_repo.empty(); if (is_hf) { @@ -806,7 +807,8 @@ common_download_model_result common_download_model(const common_params_model & if (!hf.preset.path.empty()) { // if preset.ini exists, only download that file alone tasks.push_back({hf.preset.url, hf.preset.local_path}); - } else { + } else if (!preset_only) { + // only add other files if we're NOT in preset-only mode (normal run, non-router) for (const auto & f : hf.model_files) { tasks.push_back({f.url, f.local_path}); } diff --git a/common/download.h b/common/download.h index 8dbf07836f..755e34ea8c 100644 --- a/common/download.h +++ b/common/download.h @@ -55,6 +55,7 @@ struct common_download_opts { bool skip_download = false; // if true, only validation is performed, common_skip_download_exception may be thrown if the file is missing or invalid bool download_mmproj = false; bool download_mtp = false; + bool preset_only = false; // if true, only check & download remote preset (for router mode) common_download_callback * callback = nullptr; }; diff --git a/conversion/__init__.py b/conversion/__init__.py index 00192cf33a..c6af6f7318 100644 --- a/conversion/__init__.py +++ b/conversion/__init__.py @@ -96,6 +96,7 @@ TEXT_MODEL_MAP: dict[str, str] = { "GraniteMoeHybridForCausalLM": "granite", "GraniteMoeSharedForCausalLM": "granite", "GraniteSpeechForConditionalGeneration": "granite", + "GraniteSpeechPlusForConditionalGeneration": "granite", "Grok1ForCausalLM": "grok", "GrokForCausalLM": "grok", "GroveMoeForCausalLM": "grovemoe", @@ -261,6 +262,7 @@ MMPROJ_MODEL_MAP: dict[str, str] = { "GlmasrModel": "ultravox", "Granite4VisionForConditionalGeneration": "granite", "GraniteSpeechForConditionalGeneration": "granite", + "GraniteSpeechPlusForConditionalGeneration": "granite", "HunYuanVLForConditionalGeneration": "hunyuan", "Idefics3ForConditionalGeneration": "smolvlm", "InternVisionModel": "internvl", diff --git a/conversion/granite.py b/conversion/granite.py index 53441fe570..8367ed225d 100644 --- a/conversion/granite.py +++ b/conversion/granite.py @@ -348,6 +348,34 @@ class GraniteSpeechMmprojModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("GraniteSpeechPlusForConditionalGeneration") +class GraniteSpeechPlusMmprojModel(GraniteSpeechMmprojModel): + """Conversion for GraniteSpeechPlus - extends GraniteSpeech with feature layer concatenation""" + has_vision_encoder = False + has_audio_encoder = True + + def set_gguf_parameters(self): + assert self.hparams_audio is not None + super().set_gguf_parameters() + + # Add feature_layer if present in encoder config + if feature_layers := self.hparams_audio.get("cat_hidden_layers"): + self.gguf_writer.add_audio_feature_layers(feature_layers) + logger.info(f"gguf: audio feature_layers = {feature_layers}") + + # Validate projector dimension matches concatenated encoder output + hidden_dim = self.hparams_audio["hidden_dim"] + expected_dim = hidden_dim * (len(feature_layers) + 1) + projector_dim = self.global_config["projector_config"]["encoder_hidden_size"] + + if projector_dim != expected_dim: + raise ValueError( + f"Projector encoder_hidden_size ({projector_dim}) does not match " + f"expected concatenated dimension ({expected_dim}). " + f"Expected: hidden_dim ({hidden_dim}) * (len(feature_layers) + 1) = {expected_dim}" + ) + + @ModelBase.register("Granite4VisionForConditionalGeneration") class Granite4VisionMmprojModel(MmprojModel): has_vision_encoder = True diff --git a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl index 9703b693e5..f5c6fb3e84 100644 --- a/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl +++ b/ggml/src/ggml-opencl/kernels/gemv_noshuffle_q8_0_f32.cl @@ -174,7 +174,7 @@ __kernel void kernel_gemv_noshuffle_q8_0_f32( regA.s6 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 6)).x; regA.s7 = read_imageui(src0_q, (gid + k * BLOCK_STRIDE_A + LINE_STRIDE_A * 7)).x; - dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, regS, regB); + dequantizeBlockAccum_ns_sgbroadcast_1(totalSum, regA, convert_float(regS), regB); } // reduction in local memory, assumes #wave=4 diff --git a/ggml/src/ggml-vulkan/CMakeLists.txt b/ggml/src/ggml-vulkan/CMakeLists.txt index 2d9e85794a..5aeb6e97b1 100644 --- a/ggml/src/ggml-vulkan/CMakeLists.txt +++ b/ggml/src/ggml-vulkan/CMakeLists.txt @@ -108,6 +108,9 @@ if (Vulkan_FOUND) if (GGML_VULKAN_CHECK_RESULTS) add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + # the result-checking path computes a CPU reference graph via + # ggml_graph_compute_with_ctx(), which is defined in ggml-cpu + target_link_libraries(ggml-vulkan PRIVATE ggml-cpu) endif() if (GGML_VULKAN_DEBUG) @@ -129,6 +132,8 @@ if (Vulkan_FOUND) if (GGML_VULKAN_RUN_TESTS) add_compile_definitions(GGML_VULKAN_RUN_TESTS) + # the test path also calls ggml_graph_compute_with_ctx() (ggml-cpu) + target_link_libraries(ggml-vulkan PRIVATE ggml-cpu) endif() # Set up toolchain for host compilation whether cross-compiling or not diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6f877f15ce..c00a2e9ee9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -905,11 +905,12 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key { ggml_type src0_type; ggml_type src1_type; int vectorized; + uint32_t num_cols; bool use_mmvq; bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && - use_mmvq == other.use_mmvq; + num_cols == other.num_cols && use_mmvq == other.use_mmvq; } }; @@ -919,6 +920,7 @@ struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.num_cols); ggml_webgpu_hash_combine(seed, key.use_mmvq); return seed; } @@ -993,11 +995,12 @@ struct ggml_webgpu_mul_mat_id_pipeline_key { ggml_type src0_type; ggml_type src1_type; uint32_t n_experts; + uint32_t num_cols; int vectorized; bool operator==(const ggml_webgpu_mul_mat_id_pipeline_key & other) const { return src0_type == other.src0_type && src1_type == other.src1_type && n_experts == other.n_experts && - vectorized == other.vectorized; + num_cols == other.num_cols && vectorized == other.vectorized; } }; @@ -1007,6 +1010,7 @@ struct ggml_webgpu_mul_mat_id_pipeline_key_hash { ggml_webgpu_hash_combine(seed, key.src0_type); ggml_webgpu_hash_combine(seed, key.src1_type); ggml_webgpu_hash_combine(seed, key.n_experts); + ggml_webgpu_hash_combine(seed, key.num_cols); ggml_webgpu_hash_combine(seed, key.vectorized); return seed; } @@ -1107,7 +1111,7 @@ inline bool ggml_webgpu_can_use_mmvq(const ggml_tensor * src0, const ggml_tensor * src1, bool supports_dot_product, const std::string & vendor) { - if (src1->ne[1] == 1) { + if (src1->ne[1] <= 4) { bool supports_dp4a = vendor == "amd" || vendor == "intel" || vendor == "nvidia"; if (supports_dp4a && supports_dot_product) { switch (src1->type) { @@ -1889,6 +1893,7 @@ class ggml_webgpu_shader_lib { (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? 1 : 0; + key.num_cols = context.dst->ne[1]; key.use_mmvq = ggml_webgpu_can_use_mmvq(context.src0, context.src1, context.supports_dot_product, context.vendor); @@ -2004,6 +2009,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=") + std::to_string(key.num_cols)); auto processed = preprocessor.preprocess(shader_src, defines); auto decisions = std::make_shared(); @@ -2421,6 +2427,7 @@ class ggml_webgpu_shader_lib { if (key.vectorized) { variant += "_vectorized"; } + defines.push_back(std::string("NUM_COLS=1")); defines.push_back(std::string("N_EXPERTS=") + std::to_string(key.n_experts)); diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f71d1aee73..e8eafd185a 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -1418,15 +1418,17 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & const size_t dst_offset = ggml_webgpu_tensor_offset(dst); const size_t q8_src1_align_offset = ROUNDUP_POW2( dst_offset + ggml_nbytes(dst), ctx->global_ctx->capabilities.limits.minStorageBufferOffsetAlignment); - const size_t q8_src1_binding_size = - ROUNDUP_POW2(src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), - WEBGPU_STORAGE_BUF_BINDING_MULT); + const size_t q8_src1_binding_size = ROUNDUP_POW2( + src1->ne[3] * src1->ne[2] * src1->ne[1] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)), + WEBGPU_STORAGE_BUF_BINDING_MULT); std::vector q8_params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), (uint32_t) src1->ne[0], + (uint32_t) src1->ne[1], (uint32_t) src1->ne[2], (uint32_t) src1->ne[3], }; @@ -1442,7 +1444,7 @@ static void ggml_webgpu_quantize_q8_dispatch(webgpu_context & uint32_t q8_wg_x = 1; uint32_t q8_wg_y = 1; const uint32_t wg_per_vec = (src0->ne[0] / 4 + (q8_wg_size - 1)) / q8_wg_size; - const uint32_t q8_total_wg = src1->ne[2] * src1->ne[3] * wg_per_vec; + const uint32_t q8_total_wg = src1->ne[1] * src1->ne[2] * src1->ne[3] * wg_per_vec; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; compute_2d_workgroups(q8_total_wg, max_wg_per_dim, q8_wg_x, q8_wg_y); @@ -1456,7 +1458,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * dst) { // Determine if this is a mat-vec operation - bool is_vec = (dst->ne[1] == 1); + bool use_mat_vec = (dst->ne[1] <= 4); // use MMVQ path for mat-vec bool use_mmvq = ggml_webgpu_can_use_mmvq(src0, src1, ctx->global_ctx->capabilities.supports_dot_product, @@ -1482,7 +1484,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, webgpu_pipeline pipeline; std::vector dispatches; - if (is_vec) { + if (use_mat_vec) { if (use_mmvq) { ggml_webgpu_quantize_q8_dispatch(ctx, src0, src1, dst, dispatches); } @@ -1529,7 +1531,7 @@ static webgpu_encoded_op ggml_webgpu_mul_mat(webgpu_context & ctx, uint32_t wg_y = 1; const uint32_t max_wg_per_dim = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - if (is_vec) { + if (use_mat_vec) { auto * decisions = static_cast(pipeline.context.get()); uint32_t batches = dst->ne[2] * dst->ne[3]; @@ -3691,8 +3693,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alloc_size(ggml_backend_buffer ggml_webgpu_can_use_mmvq(src0, src1, ctx->webgpu_global_ctx->capabilities.supports_dot_product, ctx->webgpu_global_ctx->vendor); if (use_mmvq) { - const size_t q8_src1_size = - src1->ne[3] * src1->ne[2] * (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); + const size_t q8_src1_size = src1->ne[3] * src1->ne[2] * src1->ne[1] * + (36 /* sizeof(q8_1) */ * (src1->ne[0] / /* block_size */ 32)); res = ROUNDUP_POW2(res + q8_src1_size + ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment, WEBGPU_STORAGE_BUF_BINDING_MULT); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl index 6ff9bcf2df..78ae955e6b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_id_vec.wgsl @@ -103,7 +103,7 @@ fn main( #ifdef USE_SUBGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); + let subgroup_total = subgroupAdd(acc[0][row]); if (subgroup_invocation_id == 0u) { partial_sums[partial_index(row, subgroup_id)] = subgroup_total; } @@ -126,7 +126,7 @@ fn main( #ifdef USE_WORKGROUP_REDUCTION for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; + partial_sums[partial_index(row, thread_id)] = acc[0][row]; } workgroupBarrier(); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index f0a7fbd059..ebdf09513e 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -91,61 +91,67 @@ fn main( let dst_idx_base = params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row_base; #ifdef MMVQ - let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * (params.k / 32u); + let src1q_idx_base = (src13_idx * params.bs02 * params.broadcast2 + src12_idx) * params.n * (params.k / 32u); let acc = accumulate_vec_q_dot(thread_id, row_base, src0_batch_offset, src1q_idx_base); #else let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; let acc = accumulate_vec_dot(thread_id, row_base, src0_batch_offset, src1_idx_base); #endif + for (var col = 0u;col < NUM_COLS;col += 1) { + #ifdef USE_SUBGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - let subgroup_total = subgroupAdd(acc[row]); - if (subgroup_invocation_id == 0u) { - partial_sums[partial_index(row, subgroup_id)] = subgroup_total; - } - } + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + let subgroup_total = subgroupAdd(acc[col][row]); + if (subgroup_invocation_id == 0u) { + partial_sums[partial_index(row, subgroup_id)] = subgroup_total; + } + } - workgroupBarrier(); + workgroupBarrier(); - for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { - let output_row = row_base + row; - var row_acc = 0.0f; - for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { - row_acc += partial_sums[partial_index(row, k)]; - } - let row_total = subgroupAdd(row_acc); - if (subgroup_invocation_id == 0) { - dst[dst_idx_base + row] = row_total; - } - } + for (var row = subgroup_id; (row < OUTPUTS_PER_WG) && (row_base + row < params.m); row += num_subgroups) { + let output_row = row_base + row; + var row_acc = 0.0f; + for (var k = subgroup_invocation_id; k < num_subgroups; k += subgroup_size) { + row_acc += partial_sums[partial_index(row, k)]; + } + let row_total = subgroupAdd(row_acc); + if (subgroup_invocation_id == 0) { + dst[dst_idx_base + col * params.m + row] = row_total; + } + } #endif #ifdef USE_WORKGROUP_REDUCTION - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] = acc[row]; - } + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] = acc[col][row]; + } + + workgroupBarrier(); + + var stride = WG_SIZE / 2u; + + while (stride > 0) { + if (thread_id < stride) { + for (var row = 0u; row < OUTPUTS_PER_WG; row++) { + partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; + } + } + + workgroupBarrier(); + stride = stride / 2; + } + + if (thread_id < OUTPUTS_PER_WG) { + let output_row = row_base + thread_id; + if (output_row < params.m) { + dst[dst_idx_base + col * params.m + thread_id] = partial_sums[partial_index(thread_id, 0)]; + } + } +#endif workgroupBarrier(); - var stride = WG_SIZE / 2u; - - while (stride > 0) { - if (thread_id < stride) { - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { - partial_sums[partial_index(row, thread_id)] += partial_sums[partial_index(row, thread_id + stride)]; - } - } - - workgroupBarrier(); - stride = stride / 2; } - - if (thread_id < OUTPUTS_PER_WG) { - let output_row = row_base + thread_id; - if (output_row < params.m) { - dst[dst_idx_base + thread_id] = partial_sums[partial_index(thread_id, 0)]; - } - } -#endif } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl index 08753b9d64..b0703fe906 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_acc.tmpl @@ -32,8 +32,8 @@ fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { #endif #ifdef MUL_ACC_FLOAT -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let k_vec = params.k / VEC_SIZE; let src1_idx_base_vec = src1_idx_base / VEC_SIZE; @@ -41,12 +41,18 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src // Each thread walks K, loads from the vector, and updates // a small block of output rows held in registers. for (var k = thread_id; k < k_vec; k += WG_SIZE) { - let x = src1[src1_idx_base_vec + k]; + var x_vals: array; + for (var col = 0u;col < NUM_COLS;col += 1) { + x_vals[col] = src1[src1_idx_base_vec + col * (params.stride_11 / VEC_SIZE) + k]; + } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let src0_idx = (src0_batch_offset + output_row * params.stride_01) / VEC_SIZE + k; - acc[row] += inner_dot(src0[src0_idx], x); + let w = src0[src0_idx]; + for (var col = 0u;col < NUM_COLS;col += 1) { + acc[col][row] += inner_dot(w, x_vals[col]); + } } } } @@ -60,30 +66,33 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 16 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let q_byte = load_u32_at_src0(block_byte_base + 2u + thread_within_block) & 0xFFu; - var row_sum = 0.0; - for (var bit = 0u; bit < 8u; bit++) { - let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); - row_sum += w * x_block[bit]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var bit = 0u; bit < 8u; bit++) { + let w = select(-d, d, ((q_byte >> bit) & 1u) != 0u); + row_sum += w * x_block[col][bit]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -97,35 +106,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; - let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = (f32(q_byte & 0xFu) - 8.0) * d; + let q_hi = (f32((q_byte >> 4u) & 0xFu) - 8.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -139,36 +150,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 20 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(q_byte & 0xFu) * d + m; - let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(q_byte & 0xFu) * d + m; + let q_hi = f32((q_byte >> 4u) & 0xFu) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -182,19 +195,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 22 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -203,18 +217,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 2u); let q_packed = load_u32_at_src0(block_byte_base + 6u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; - let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = (f32((q_byte & 0xFu) | qh_lo) - 16.0) * d; + let q_hi = (f32(((q_byte >> 4u) & 0xFu) | qh_hi) - 16.0) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -228,19 +243,20 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 24 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -250,18 +266,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_packed = load_u32_at_src0(block_byte_base + 4u); let q_packed = load_u32_at_src0(block_byte_base + 8u + 4u * thread_within_block); let qh_shift = thread_within_block * 4u; - var row_sum = 0.0; - - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; - let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; - let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; - let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let qh_lo = ((qh_packed >> (qh_shift + byte_idx)) << 4u) & 0x10u; + let qh_hi = (qh_packed >> (qh_shift + byte_idx + 12u)) & 0x10u; + let q_lo = f32((q_byte & 0xFu) | qh_lo) * d + m; + let q_hi = f32(((q_byte >> 4u) & 0xFu) | qh_hi) * d + m; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -275,33 +292,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 34 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 2u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -315,34 +337,39 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 36 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * ELEMS_PER_THREAD; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); let m = f32(load_f16_at_src0(block_byte_base + 2u)); - var row_sum = 0.0; - + var q_packed: array; for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { - let q_packed = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_val = f32(get_byte_i32(q_packed, byte_idx)) * d + m; - row_sum += q_val * x_block[packed_idx * 4u + byte_idx]; - } + q_packed[packed_idx] = load_u32_at_src0(block_byte_base + 4u + 4u * (thread_within_block * 2u + packed_idx)); + } + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var packed_idx = 0u; packed_idx < ELEMS_PER_THREAD / 4u; packed_idx++) { + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_val = f32(get_byte_i32(q_packed[packed_idx], byte_idx)) * d + m; + row_sum += q_val * x_block[col][packed_idx * 4u + byte_idx]; + } + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -355,8 +382,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 84 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -379,14 +406,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 64u + i]); - x_block[i + 12u] = f32(src1[x_base + 96u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 64u + i]); + x_block[col][i + 12u] = f32(src1[x_base + col * params.stride_11 + 96u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -404,30 +432,32 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qs0 = q_u32 & 0xFFFFu; let qs1 = q_u32 >> 16u; - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - var acc1 = vec4(0.0, 0.0, 0.0, 0.0); - var acc2 = vec4(0.0, 0.0, 0.0, 0.0); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + var acc1 = vec4(0.0, 0.0, 0.0, 0.0); + var acc2 = vec4(0.0, 0.0, 0.0, 0.0); - sumy[0] = x_block[0] + x_block[1] + x_block[2] + x_block[3]; - sumy[1] = x_block[4] + x_block[5] + x_block[6] + x_block[7]; - sumy[2] = x_block[8] + x_block[9] + x_block[10] + x_block[11]; - sumy[3] = x_block[12] + x_block[13] + x_block[14] + x_block[15]; + sumy[0] = x_block[col][0] + x_block[col][1] + x_block[col][2] + x_block[col][3]; + sumy[1] = x_block[col][4] + x_block[col][5] + x_block[col][6] + x_block[col][7]; + sumy[2] = x_block[col][8] + x_block[col][9] + x_block[col][10] + x_block[col][11]; + sumy[3] = x_block[col][12] + x_block[col][13] + x_block[col][14] + x_block[col][15]; - acc1[0] = x_block[0] * f32(qs0 & 0x0003u) + x_block[2] * f32(qs1 & 0x0003u); - acc2[0] = x_block[1] * f32(qs0 & 0x0300u) + x_block[3] * f32(qs1 & 0x0300u); - acc1[1] = x_block[4] * f32(qs0 & 0x000Cu) + x_block[6] * f32(qs1 & 0x000Cu); - acc2[1] = x_block[5] * f32(qs0 & 0x0C00u) + x_block[7] * f32(qs1 & 0x0C00u); - acc1[2] = x_block[8] * f32(qs0 & 0x0030u) + x_block[10] * f32(qs1 & 0x0030u); - acc2[2] = x_block[9] * f32(qs0 & 0x3000u) + x_block[11] * f32(qs1 & 0x3000u); - acc1[3] = x_block[12] * f32(qs0 & 0x00C0u) + x_block[14] * f32(qs1 & 0x00C0u); - acc2[3] = x_block[13] * f32(qs0 & 0xC000u) + x_block[15] * f32(qs1 & 0xC000u); + acc1[0] = x_block[col][0] * f32(qs0 & 0x0003u) + x_block[col][2] * f32(qs1 & 0x0003u); + acc2[0] = x_block[col][1] * f32(qs0 & 0x0300u) + x_block[col][3] * f32(qs1 & 0x0300u); + acc1[1] = x_block[col][4] * f32(qs0 & 0x000Cu) + x_block[col][6] * f32(qs1 & 0x000Cu); + acc2[1] = x_block[col][5] * f32(qs0 & 0x0C00u) + x_block[col][7] * f32(qs1 & 0x0C00u); + acc1[2] = x_block[col][8] * f32(qs0 & 0x0030u) + x_block[col][10] * f32(qs1 & 0x0030u); + acc2[2] = x_block[col][9] * f32(qs0 & 0x3000u) + x_block[col][11] * f32(qs1 & 0x3000u); + acc1[3] = x_block[col][12] * f32(qs0 & 0x00C0u) + x_block[col][14] * f32(qs1 & 0x00C0u); + acc2[3] = x_block[col][13] * f32(qs0 & 0xC000u) + x_block[col][15] * f32(qs1 & 0xC000u); - acc[row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + - (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + - (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + - (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) - - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + - sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + acc[col][row] += dall * ((acc1[0] + (1.0/256.0) * acc2[0]) * f32(sc0 & 0xFu) + + (acc1[1] + (1.0/256.0) * acc2[1]) * f32(sc2 & 0xFu) / 4.0 + + (acc1[2] + (1.0/256.0) * acc2[2]) * f32(sc4 & 0xFu) / 16.0 + + (acc1[3] + (1.0/256.0) * acc2[3]) * f32(sc6 & 0xFu) / 64.0) + - dmin * (sumy[0] * f32(sc0 & 0xF0u) + sumy[1] * f32(sc2 & 0xF0u) + + sumy[2] * f32(sc4 & 0xF0u) + sumy[3] * f32(sc6 & 0xF0u)); + } } } } @@ -440,8 +470,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -485,12 +515,13 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 8u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 8u] = f32(src1[x_base + 32u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 8u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 8u] = f32(src1[x_base + col * params.stride_11 + 32u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -516,28 +547,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let h_u32_0 = load_u32_at_src0(block_byte_base + h_byte + 0u); let h_u32_1 = load_u32_at_src0(block_byte_base + h_byte + 4u); - var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; - var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; + for (var col = 0u;col < NUM_COLS;col += 1) { + var s1 = 0.0; var s2 = 0.0; var s3 = 0.0; + var s4 = 0.0; var s5 = 0.0; var s6 = 0.0; - for (var l = 0u; l < 8u; l += 2u) { - let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); - let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); - let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); - let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); + for (var l = 0u; l < 8u; l += 2u) { + let q_u32 = select(q_u32_0, q_u32_1, l >= 4u); + let qs = select(q_u32 & 0xFFFFu, q_u32 >> 16u, (l & 2u) != 0u); + let h_u32 = select(h_u32_0, h_u32_1, l >= 4u); + let hv = select(h_u32 & 0xFFFFu, h_u32 >> 16u, (l & 2u) != 0u); - s1 += x_block[l + 0u] * f32(qs & qm0); - s2 += x_block[l + 1u] * f32(qs & qm1); - s3 += select(0.0, x_block[l + 0u], (hv & hm0) == 0u) + - select(0.0, x_block[l + 1u], (hv & hm1) == 0u); - s4 += x_block[l + 8u] * f32(qs & qm2); - s5 += x_block[l + 9u] * f32(qs & qm3); - s6 += select(0.0, x_block[l + 8u], (hv & hm2) == 0u) + - select(0.0, x_block[l + 9u], (hv & hm3) == 0u); + s1 += x_block[col][l + 0u] * f32(qs & qm0); + s2 += x_block[col][l + 1u] * f32(qs & qm1); + s3 += select(0.0, x_block[col][l + 0u], (hv & hm0) == 0u) + + select(0.0, x_block[col][l + 1u], (hv & hm1) == 0u); + s4 += x_block[col][l + 8u] * f32(qs & qm2); + s5 += x_block[col][l + 9u] * f32(qs & qm3); + s6 += select(0.0, x_block[col][l + 8u], (hv & hm2) == 0u) + + select(0.0, x_block[col][l + 9u], (hv & hm3) == 0u); + } + + let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); + let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); + acc[col][row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); } - - let d1 = d * (s1 + (1.0/256.0) * s2 - s3 * v1); - let d2 = d * (s4 + (1.0/256.0) * s5 - s6 * v2); - acc[row] += (d1 * scale0 + 0.25 * d2 * scale1) / f32(1u << shift); } } } @@ -550,8 +583,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 144 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -573,12 +606,15 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } for (var row = 0u; row < OUTPUTS_PER_WG; row++) { @@ -613,23 +649,25 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q1_u32 = load_u32_at_src0_aligned(block_byte_base + 16u + q_offset); let q2_u32 = load_u32_at_src0_aligned(block_byte_base + 80u + q_offset); - var dot = vec4(0.0, 0.0, 0.0, 0.0); - var sumx = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - dot[0] += x_block[i] * f32(q1b & 0x0Fu); - dot[1] += x_block[i + 4u] * f32(q1b >> 4u); - dot[2] += x_block[i + 8u] * f32(q2b & 0x0Fu); - dot[3] += x_block[i + 12u] * f32(q2b >> 4u); - sumx[0] += x_block[i]; - sumx[1] += x_block[i + 4u]; - sumx[2] += x_block[i + 8u]; - sumx[3] += x_block[i + 12u]; - } + for (var col = 0u;col < NUM_COLS;col += 1) { + var dot = vec4(0.0, 0.0, 0.0, 0.0); + var sumx = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + dot[0] += x_block[col][i] * f32(q1b & 0x0Fu); + dot[1] += x_block[col][i + 4u] * f32(q1b >> 4u); + dot[2] += x_block[col][i + 8u] * f32(q2b & 0x0Fu); + dot[3] += x_block[col][i + 12u] * f32(q2b >> 4u); + sumx[0] += x_block[col][i]; + sumx[1] += x_block[col][i + 4u]; + sumx[2] += x_block[col][i + 8u]; + sumx[3] += x_block[col][i + 12u]; + } - acc[row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) - - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + acc[col][row] += d * (dot[0] * scale0 + dot[1] * scale1 + dot[2] * scale2 + dot[3] * scale3) + - dmin * (sumx[0] * min0 + sumx[1] * min1 + sumx[2] * min2 + sumx[3] * min3); + } } } } @@ -642,8 +680,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 176 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -671,14 +709,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 4u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + 32u + i]); - x_block[i + 8u] = f32(src1[x_base + 128u + i]); - x_block[i + 12u] = f32(src1[x_base + 160u + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var i = 0u; i < 4u; i++) { + x_block[col][i] = f32(src1[col_base + i]); + x_block[col][i + 4u] = f32(src1[col_base + 32u + i]); + x_block[col][i + 8u] = f32(src1[col_base + 128u + i]); + x_block[col][i + 12u] = f32(src1[col_base + 160u + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -712,37 +752,39 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q2_u32 = load_u32_at_src0_aligned(block_byte_base + q_offset + 64u); let qh_u32 = load_u32_at_src0_aligned(block_byte_base + qh_offset); - var vals = vec4(0.0, 0.0, 0.0, 0.0); - var sumy = vec4(0.0, 0.0, 0.0, 0.0); - for (var i = 0u; i < 4u; i++) { - let q1b = byte_of(q1_u32, i); - let q2b = byte_of(q2_u32, i); - let qhb = byte_of(qh_u32, i); + for (var col = 0u;col < NUM_COLS;col += 1) { + var vals = vec4(0.0, 0.0, 0.0, 0.0); + var sumy = vec4(0.0, 0.0, 0.0, 0.0); + for (var i = 0u; i < 4u; i++) { + let q1b = byte_of(q1_u32, i); + let q2b = byte_of(q2_u32, i); + let qhb = byte_of(qh_u32, i); - let yl0 = x_block[i]; - let yl8 = x_block[i + 4u]; - let yh0 = x_block[i + 8u]; - let yh8 = x_block[i + 12u]; + let yl0 = x_block[col][i]; + let yl8 = x_block[col][i + 4u]; + let yh0 = x_block[col][i + 8u]; + let yh8 = x_block[col][i + 12u]; - sumy[0] += yl0; - sumy[1] += yl8; - sumy[2] += yh0; - sumy[3] += yh8; + sumy[0] += yl0; + sumy[1] += yl8; + sumy[2] += yh0; + sumy[3] += yh8; - let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); - let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); - let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); - let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); + let q0 = f32((q1b & 0x0Fu) | select(0u, 0x10u, (qhb & hm1) != 0u)); + let q1 = f32((q1b >> 4u) | select(0u, 0x10u, (qhb & hm2) != 0u)); + let q2 = f32((q2b & 0x0Fu) | select(0u, 0x10u, (qhb & hm3) != 0u)); + let q3 = f32((q2b >> 4u) | select(0u, 0x10u, (qhb & hm4) != 0u)); - vals[0] += yl0 * q0; - vals[1] += yl8 * q1; - vals[2] += yh0 * q2; - vals[3] += yh8 * q3; + vals[0] += yl0 * q0; + vals[1] += yl8 * q1; + vals[2] += yh0 * q2; + vals[3] += yh8 * q3; + } + + acc[col][row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) + - dmin * (sumy[0] * m0 + sumy[1] * m1 + + sumy[2] * m4 + sumy[3] * m5); } - - acc[row] += d * (f0 * vals[0] + f1 * vals[1] + f4 * vals[2] + f5 * vals[3]) - - dmin * (sumy[0] * m0 + sumy[1] * m1 + - sumy[2] * m4 + sumy[3] * m5); } } } @@ -755,8 +797,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 210 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -777,14 +819,16 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var l = 0u; l < 4u; l++) { - x_block[l] = f32(src1[x_base + l]); - x_block[l + 4u] = f32(src1[x_base + 32u + l]); - x_block[l + 8u] = f32(src1[x_base + 64u + l]); - x_block[l + 12u] = f32(src1[x_base + 96u + l]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + let col_base = x_base + col * params.stride_11; + for (var l = 0u; l < 4u; l++) { + x_block[col][l] = f32(src1[col_base + l]); + x_block[col][l + 4u] = f32(src1[col_base + 32u + l]); + x_block[col][l + 8u] = f32(src1[col_base + 64u + l]); + x_block[col][l + 12u] = f32(src1[col_base + 96u + l]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -802,26 +846,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc4 = sbyte_of(sc_u32_1, sc_byte_pos); let sc6 = sbyte_of(sc_u32_1, sc_byte_pos + 2u); - var sums = vec4(0.0, 0.0, 0.0, 0.0); + for (var col = 0u;col < NUM_COLS;col += 1) { + var sums = vec4(0.0, 0.0, 0.0, 0.0); - for (var l = 0u; l < 4u; l++) { - let q1b = byte_of(ql1_u32, l); - let q2b = byte_of(ql2_u32, l); - let qhb = byte_of(qh_u32, l); + for (var l = 0u; l < 4u; l++) { + let q1b = byte_of(ql1_u32, l); + let q2b = byte_of(ql2_u32, l); + let qhb = byte_of(qh_u32, l); - let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); - let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); - let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); - let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); + let dq0 = f32(i32((q1b & 0x0Fu) | ((qhb & 0x03u) << 4u)) - 32); + let dq1 = f32(i32((q2b & 0x0Fu) | ((qhb & 0x0Cu) << 2u)) - 32); + let dq2 = f32(i32((q1b >> 4u) | (qhb & 0x30u)) - 32); + let dq3 = f32(i32((q2b >> 4u) | ((qhb & 0xC0u) >> 2u)) - 32); - sums[0] += x_block[l] * dq0; - sums[1] += x_block[l + 4u] * dq1; - sums[2] += x_block[l + 8u] * dq2; - sums[3] += x_block[l + 12u] * dq3; + sums[0] += x_block[col][l] * dq0; + sums[1] += x_block[col][l + 4u] * dq1; + sums[2] += x_block[col][l + 8u] * dq2; + sums[3] += x_block[col][l + 12u] * dq3; + } + + acc[col][row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + + sums[2] * f32(sc4) + sums[3] * f32(sc6)); } - - acc[row] += d * (sums[0] * f32(sc0) + sums[1] * f32(sc2) + - sums[2] * f32(sc4) + sums[3] * f32(sc6)); } } } @@ -834,8 +880,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 50 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -850,11 +896,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -866,20 +913,22 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); let qs_w = load_u32_at_src0(block_byte_base + 2u + sub_blk * 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let ig = (qs_byte | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -892,8 +941,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 56 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -908,11 +957,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -936,26 +986,28 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let qh_lo = qh & 0xFFu; let qh_hi = (qh >> 8u) & 0xFFu; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); - let sub_scale = (sc_u16 >> bit_off) & 0x7u; - let dl = d * f32(2u * sub_scale + 1u); - let qh_byte = select(qh_lo, qh_hi, l >= 2u); - let ll2 = l % 2u; - let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); - let ig = grid_idx * 8u; - let gw = iq1_grid[ig / 16u]; - let bit_base = (ig % 16u) * 2u; - for (var j = 0u; j < 8u; j++) { - let g = (gw >> (bit_base + j * 2u)) & 3u; - let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); - row_sum += dl * (gs + delta) * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let bit_off = 6u * (sub_blk % 2u) + 3u * (l / 2u); + let sub_scale = (sc_u16 >> bit_off) & 0x7u; + let dl = d * f32(2u * sub_scale + 1u); + let qh_byte = select(qh_lo, qh_hi, l >= 2u); + let ll2 = l % 2u; + let grid_idx = get_byte(qs_w, l) | (((qh_byte >> (4u * ll2)) & 7u) << 8u); + let delta = select(IQ1_DELTA, -IQ1_DELTA, ((qh_byte >> (3u + 4u * ll2)) & 1u) != 0u); + let ig = grid_idx * 8u; + let gw = iq1_grid[ig / 16u]; + let bit_base = (ig % 16u) * 2u; + for (var j = 0u; j < 8u; j++) { + let g = (gw >> (bit_base + j * 2u)) & 3u; + let gs = select(f32(g), f32(g) - 4.0, (g & 2u) != 0u); + row_sum += dl * (gs + delta) * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -968,8 +1020,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 66 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -984,11 +1036,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -999,22 +1052,24 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux_hi >> 28u; let db = d * (0.5 + f32(ls)) * 0.25; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; - let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xxs_grid[grid_idx * 2u]; - let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let grid_idx = (aux_lo >> (8u * l)) & 0xFFu; + let signs_idx = (aux_hi >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xxs_grid[grid_idx * 2u]; + let gw_hi = iq2xxs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1027,8 +1082,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 74 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1043,11 +1098,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1058,27 +1114,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let scales_word = load_u32_at_src0(block_byte_base + 66u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(scales_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let half2 = (l % 2u) * 16u; - let qs_val = (qs_word >> half2) & 0xFFFFu; - let grid_idx = qs_val & 0x1FFu; - let signs_idx = (qs_val >> 9u) & 0x7Fu; - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let gw_lo = iq2xs_grid[grid_idx * 2u]; - let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let half2 = (l % 2u) * 16u; + let qs_val = (qs_word >> half2) & 0xFFFFu; + let grid_idx = qs_val & 0x1FFu; + let signs_idx = (qs_val >> 9u) & 0x7Fu; + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let gw_lo = iq2xs_grid[grid_idx * 2u]; + let gw_hi = iq2xs_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1091,8 +1149,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 82 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1107,11 +1165,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1124,24 +1183,26 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sc_word = load_u32_at_src0(block_byte_base + 74u + (sub_blk / 4u) * 4u); let scales_byte = get_byte(sc_word, sub_blk % 4u); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_byte = get_byte(qs_w, l); - let sign_byte = get_byte(sg_w, l); - let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); - let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; - let db = d * (0.5 + f32(sub_scale)) * 0.25; - let gw_lo = iq2s_grid[grid_idx * 2u]; - let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; - for (var j = 0u; j < 8u; j++) { - let gw = select(gw_hi, gw_lo, j < 4u); - let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); - let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - row_sum += db * b * s * x_block[ll * 8u + j]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_byte = get_byte(qs_w, l); + let sign_byte = get_byte(sg_w, l); + let grid_idx = qs_byte | (((qh_byte >> (2u * l)) & 3u) << 8u); + let sub_scale = (scales_byte >> (4u * (l / 2u))) & 0xFu; + let db = d * (0.5 + f32(sub_scale)) * 0.25; + let gw_lo = iq2s_grid[grid_idx * 2u]; + let gw_hi = iq2s_grid[grid_idx * 2u + 1u]; + for (var j = 0u; j < 8u; j++) { + let gw = select(gw_hi, gw_lo, j < 4u); + let b = f32((gw >> ((j & 3u) * 8u)) & 0xFFu); + let s = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + row_sum += db * b * s * x_block[col][ll * 8u + j]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1154,8 +1215,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 98 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1170,11 +1231,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1186,27 +1248,29 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let ls = aux >> 28u; let db = d * (0.5 + f32(ls)) * 0.5; - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let signs_idx = (aux >> (7u * l)) & 0x7Fu; - let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; - let grid1 = iq3xxs_grid[grid_idx_0]; - let grid2 = iq3xxs_grid[grid_idx_1]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let grid_idx_0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let grid_idx_1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let signs_idx = (aux >> (7u * l)) & 0x7Fu; + let signs = (ksigns_iq2xs[signs_idx / 4u] >> ((signs_idx % 4u) * 8u)) & 0xFFu; + let grid1 = iq3xxs_grid[grid_idx_0]; + let grid2 = iq3xxs_grid[grid_idx_1]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((signs >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((signs >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1219,8 +1283,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 110 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1235,11 +1299,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1255,28 +1320,30 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let sub_scale = (scales_byte >> (4u * (sub_blk % 2u))) & 0xFu; let db = d * (1.0 + 2.0 * f32(sub_scale)); - var row_sum = 0.0; - for (var ll = 0u; ll < 2u; ll++) { - let l = slot0 + ll; - let qs_word = select(qs_hi, qs_lo, l < 2u); - let byte_pos = (l % 2u) * 2u; - let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; - let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; - let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); - let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); - let sign_byte = get_byte(sg_w, l); - let grid1 = iq3s_grid[grid_idx_1]; - let grid2 = iq3s_grid[grid_idx_2]; - for (var j = 0u; j < 4u; j++) { - let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); - let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); - let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); - let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); - row_sum += db * b1 * s1 * x_block[ll * 8u + j]; - row_sum += db * b2 * s2 * x_block[ll * 8u + j + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var ll = 0u; ll < 2u; ll++) { + let l = slot0 + ll; + let qs_word = select(qs_hi, qs_lo, l < 2u); + let byte_pos = (l % 2u) * 2u; + let qs0 = (qs_word >> (byte_pos * 8u)) & 0xFFu; + let qs1 = (qs_word >> ((byte_pos + 1u) * 8u)) & 0xFFu; + let grid_idx_1 = qs0 | (((qh_byte >> (2u * l)) & 1u) << 8u); + let grid_idx_2 = qs1 | (((qh_byte >> (2u * l + 1u)) & 1u) << 8u); + let sign_byte = get_byte(sg_w, l); + let grid1 = iq3s_grid[grid_idx_1]; + let grid2 = iq3s_grid[grid_idx_2]; + for (var j = 0u; j < 4u; j++) { + let b1 = f32((grid1 >> (j * 8u)) & 0xFFu); + let b2 = f32((grid2 >> (j * 8u)) & 0xFFu); + let s1 = select(1.0, -1.0, ((sign_byte >> j) & 1u) != 0u); + let s2 = select(1.0, -1.0, ((sign_byte >> (j + 4u)) & 1u) != 0u); + row_sum += db * b1 * s1 * x_block[col][ll * 8u + j]; + row_sum += db * b2 * s2 * x_block[col][ll * 8u + j + 4u]; + } } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1290,35 +1357,37 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 18 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4u; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4u] = f32(src1[x_base + i + 16u]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4u] = f32(src1[x_base + col * params.stride_11 + i + 16u]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let d = f32(load_f16_at_src0(block_byte_base)); - var row_sum = 0.0; - let q_packed = load_u32_at_src0(block_byte_base + 2u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; - let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_iq4nl[q_byte & 0xFu]) * d; + let q_hi = f32(kvalues_iq4nl[(q_byte >> 4u) & 0xFu]) * d; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1331,8 +1400,8 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE 256 #define BLOCK_SIZE_BYTES 136 #define THREADS_PER_BLOCK 16 -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; let block_group = thread_id / THREADS_PER_BLOCK; @@ -1346,11 +1415,12 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src for (var block = block_group; block < num_blocks; block += num_block_groups) { let x_base = src1_idx_base + block * BLOCK_SIZE + y_offset; - var x_block: array; - for (var i = 0u; i < 16u; i++) { - x_block[i] = f32(src1[x_base + i]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < 16u; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { @@ -1370,17 +1440,19 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src let q_w2 = load_u32_at_src0(block_byte_base + qs_byte_off + 8u); let q_w3 = load_u32_at_src0(block_byte_base + qs_byte_off + 12u); - var row_sum = 0.0; - for (var i = 0u; i < 16u; i++) { - let q_word = select( - select(q_w0, q_w1, i >= 4u), - select(q_w2, q_w3, i >= 12u), - i >= 8u); - let q_byte = get_byte(q_word, i % 4u); - let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); - row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[i]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var i = 0u; i < 16u; i++) { + let q_word = select( + select(q_w0, q_w1, i >= 4u), + select(q_w2, q_w3, i >= 12u), + i >= 8u); + let q_byte = get_byte(q_word, i % 4u); + let nib = select(q_byte & 0xFu, (q_byte >> 4u) & 0xFu, half == 1u); + row_sum += f32(kvalues_iq4nl[nib]) * dl * x_block[col][i]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } @@ -1394,35 +1466,38 @@ fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src #define BLOCK_SIZE_BYTES 17 #define THREADS_PER_BLOCK 4 #define ELEMS_PER_THREAD (BLOCK_SIZE/THREADS_PER_BLOCK) -fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; let thread_within_block = thread_id % 4; for (var block = thread_id/THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE/THREADS_PER_BLOCK) { let x_base = src1_idx_base + block * BLOCK_SIZE + thread_within_block * 4; - var x_block: array; - for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { - x_block[i] = f32(src1[x_base + i]); - x_block[i + 4] = f32(src1[x_base + i + 16]); + var x_block: array, NUM_COLS>; + for (var col = 0u; col < NUM_COLS;col += 1) { + for (var i = 0u; i < ELEMS_PER_THREAD / 2; i++) { + x_block[col][i] = f32(src1[x_base + col * params.stride_11 + i]); + x_block[col][i + 4] = f32(src1[x_base + col * params.stride_11 + i + 16]); + } } - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; let eu8 = get_byte(load_u32_at_src0(block_byte_base), 0); let e = ldexp(1.0, i32(eu8) - 128); - var row_sum = 0.0; let q_packed = load_u32_at_src0(block_byte_base + 1u + 4u * thread_within_block); - for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { - let q_byte = get_byte(q_packed, byte_idx); - let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; - let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; - row_sum += q_lo * x_block[byte_idx]; - row_sum += q_hi * x_block[byte_idx + 4u]; + for (var col = 0u;col < NUM_COLS;col += 1) { + var row_sum = 0.0; + for (var byte_idx = 0u; byte_idx < 4u; byte_idx++) { + let q_byte = get_byte(q_packed, byte_idx); + let q_lo = f32(kvalues_mxfp4[q_byte & 0xFu]) * e; + let q_hi = f32(kvalues_mxfp4[(q_byte >> 4u) & 0xFu]) * e; + row_sum += q_lo * x_block[col][byte_idx]; + row_sum += q_hi * x_block[col][byte_idx + 4u]; + } + acc[col][row] += row_sum; } - acc[row] += row_sum; } } } diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl index 3ef2f77ebe..6ccaf61a6a 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec_q_acc.tmpl @@ -51,10 +51,7 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_0 #ifdef MUL_ACC_Q4_1 #define BLOCK_SIZE_BYTES 20 @@ -85,10 +82,7 @@ fn get_dm(block_byte_base: u32) -> vec2 { f32(load_f16_at_src0(block_byte_base + 2u)) ); } -fn mul_q8_1(row_sum: i32, dma: vec2, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (dma.x * b_ds.x) + dma.y * b_ds.y / THREADS_PER_BLOCK; -} -#endif +#endif // MUL_ACC_Q4_1 #ifdef MUL_ACC_Q8_0 #define BLOCK_SIZE_BYTES 34 @@ -111,46 +105,48 @@ fn repack_b_dm(block: u32) -> B_DS_TYPE { fn get_dm(block_byte_base: u32) -> f32 { return f32(load_f16_at_src0(block_byte_base)); } -fn mul_q8_1(row_sum: i32, da: f32, b_ds: B_DS_TYPE) -> f32 { - return f32(row_sum) * (da * b_ds); -} -#endif +#endif // MUL_ACC_Q8_0 -#ifdef LEGACY_QUANTS -fn mmvq_dot_product(a_byte_base: u32, b_inner_id: u32, b_repacked: vec2, b_ds: B_DS_TYPE) -> f32 { - var row_sum = 0; - let a_repacked = repack_a(a_byte_base, b_inner_id); - - row_sum += dot4I8Packed(a_repacked[0], b_repacked[0]); - row_sum += dot4I8Packed(a_repacked[1], b_repacked[1]); - - return mul_q8_1(row_sum, get_dm(a_byte_base), b_ds); -} - -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +#if defined(LEGACY_QUANTS) +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let num_blocks = params.k / BLOCK_SIZE; for (var block = thread_id / THREADS_PER_BLOCK; block < num_blocks; block += WG_SIZE / THREADS_PER_BLOCK) { - let b_inner_id = thread_id % THREADS_PER_BLOCK; - let b_block_idx = src1q_idx_base + block; - - let b_repacked = repack_b_qs(b_block_idx, b_inner_id); - let b_ds = repack_b_dm(b_block_idx); - + let inner_id = thread_id % THREADS_PER_BLOCK; for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, b_inner_id, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, inner_id); + let da = get_dm(block_byte_base); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + block; + let b_repacked = repack_b_qs(src1q_idx, inner_id); + let b_ds = repack_b_dm(src1q_idx); + + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]); + +#if defined(MUL_ACC_Q4_0) + acc[col][row] += f32(row_sum) * (da * b_ds.x) - 8.0 * da * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_0 + +#if defined(MUL_ACC_Q4_1) + acc[col][row] += f32(row_sum) * (da.x * b_ds.x) + da.y * b_ds.y / THREADS_PER_BLOCK; +#endif // MUL_ACC_Q4_1 + +#if defined(MUL_ACC_Q8_0) + acc[col][row] += f32(row_sum) * (da * b_ds); +#endif // MUL_ACC_Q8_0 + } } } } return acc; } -#endif +#endif // LEGACY_QUANTS #ifdef MUL_ACC_Q2_K #define BLOCK_SIZE_BYTES 84 @@ -191,22 +187,7 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { let scale = byte_of(load_u32_at_src0_aligned(scale_byte), scale_byte & 3u); return vec2(f32(scale & 0xFu), f32(scale >> 4u)); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let scale_q = i32(scale_min.x); - let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; - - let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) - + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; - let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) - + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); - - return b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); -} -#endif +#endif // MUL_ACC_Q2_K #ifdef MUL_ACC_Q4_K #define BLOCK_SIZE_BYTES 144 @@ -265,39 +246,52 @@ fn get_scale_min(block_byte_base: u32, tid: u32) -> vec2 { return vec2(scale, min_val); } -fn mmvq_dot_product(a_byte_base: u32, tid: u32, b_repacked: vec4, b_ds: B_DS_TYPE) -> f32 { - let a_repacked = repack_a(a_byte_base, tid); - let dm = get_dm(a_byte_base); - let scale_min = get_scale_min(a_byte_base, tid); - - let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) - + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); - - // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. - return b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); -} -#endif +#endif // MUL_ACC_Q4_K #ifdef K_QUANTS -fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array { - var acc: array; +fn accumulate_vec_q_dot(thread_id: u32, row_base: u32, src0_batch_offset: u32, src1q_idx_base: u32) -> array, NUM_COLS> { + var acc: array, NUM_COLS>; let tid = thread_id % THREADS_PER_BLOCK; for (var block = thread_id / THREADS_PER_BLOCK; block < params.k / BLOCK_SIZE; block += WG_SIZE / THREADS_PER_BLOCK) { - let src1q_idx = src1q_idx_base + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; - let b_repacked = repack_b_qs(src1q_idx, tid); - let b_ds = repack_b_dm(src1q_idx); - for (var row = 0u; row < OUTPUTS_PER_WG; row++) { let output_row = row_base + row; if (output_row < params.m) { let block_byte_base = (src0_batch_offset + output_row * params.stride_01 + block) * BLOCK_SIZE_BYTES; - acc[row] += mmvq_dot_product(block_byte_base, tid, b_repacked, b_ds); + let a_repacked = repack_a(block_byte_base, tid); + let dm = get_dm(block_byte_base); + let scale_min = get_scale_min(block_byte_base, tid); + for (var col = 0u;col < NUM_COLS;col += 1) { + let src1q_idx = src1q_idx_base + col * (params.k / Q8_BLOCK_SIZE) + (block * BLOCK_SIZE + ELEMS_PER_THREAD * tid) / Q8_BLOCK_SIZE; + let b_repacked = repack_b_qs(src1q_idx, tid); + let b_ds = repack_b_dm(src1q_idx); + +#if defined(MUL_ACC_Q2_K) + let scale_q = i32(scale_min.x); + let scale_m_i8x4 = u32(scale_min.y) * 0x01010101u; + + let row_sum_d = (dot4I8Packed(b_repacked[0], a_repacked[0]) + dot4I8Packed(b_repacked[1], a_repacked[1]) + + dot4I8Packed(b_repacked[2], a_repacked[2]) + dot4I8Packed(b_repacked[3], a_repacked[3])) * scale_q; + let row_sum_m = dot4I8Packed(b_repacked[0], scale_m_i8x4) + dot4I8Packed(b_repacked[1], scale_m_i8x4) + + dot4I8Packed(b_repacked[2], scale_m_i8x4) + dot4I8Packed(b_repacked[3], scale_m_i8x4); + + acc[col][row] += b_ds * (dm.x * f32(row_sum_d) - dm.y * f32(row_sum_m)); +#endif // MUL_ACC_Q2_K + +#if defined(MUL_ACC_Q4_K) + let row_sum = dot4I8Packed(a_repacked[0], b_repacked[0]) + dot4I8Packed(a_repacked[1], b_repacked[1]) + + dot4I8Packed(a_repacked[2], b_repacked[2]) + dot4I8Packed(a_repacked[3], b_repacked[3]); + + // Each thread covers half of the Q8_1 block, so add only b_ds.y/2. + acc[col][row] += b_ds.x * dm.x * scale_min.x * f32(row_sum) - dm.y * scale_min.y * (b_ds.y / (Q8_BLOCK_SIZE / ELEMS_PER_THREAD)); +#endif // MUL_ACC_Q4_K + + } } } } return acc; } -#endif +#endif // K_QUANTS diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl index b3f1fa04b8..847b27ffad 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/quantize_q8.wgsl @@ -9,9 +9,11 @@ requires packed_4x8_integer_dot_product; struct Params { offset_src1: u32, + stride_11: u32, stride_12: u32, stride_13: u32, ne0: u32, + ne1: u32, ne2: u32, ne3: u32, }; @@ -57,25 +59,28 @@ fn main( @builtin(num_workgroups) num_wg: vec3 ) { let thread_id = local_id.x; - let num_vec4 = params.ne0 / 4u; + let ne0_vec4 = params.ne0 / 4u; - let wg_per_vec = (num_vec4 + (WG_SIZE - 1u)) / WG_SIZE; - let total_batches = wg_per_vec * params.ne2 * params.ne3; + let wg_per_vec = (ne0_vec4 + (WG_SIZE - 1u)) / WG_SIZE; + let total_batches = wg_per_vec * params.ne1 * params.ne2 * params.ne3; let wg_linear = wg_id.y * num_wg.x + wg_id.x; if (wg_linear >= total_batches) { return; } - let src13_idx = wg_linear / (params.ne2 * wg_per_vec); - let src12_idx = (wg_linear - src13_idx * (params.ne2 * wg_per_vec)) / wg_per_vec; - let src11_wg_idx = wg_linear % wg_per_vec; - let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12; + let vec_idx = wg_linear / wg_per_vec; + let src13_idx = vec_idx / (params.ne2 * params.ne1); + let vec_ne12_num = vec_idx % (params.ne2 * params.ne1); + let src12_idx = vec_ne12_num / params.ne1; + let src11_idx = vec_ne12_num % params.ne1; + let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + src11_idx * params.stride_11; let src1_idx_vec4_base = src1_idx_base / 4u; let blocks_per_row = params.ne0 / 32u; let blocks_per_wg = (WG_SIZE * 4u) / 32u; - let src1q_idx_base = (src13_idx * params.ne2 + src12_idx) * blocks_per_row; + let src1q_idx_base = ((src13_idx * params.ne2 + src12_idx) * params.ne1 + src11_idx) * blocks_per_row; + let src11_wg_idx = wg_linear % wg_per_vec; let src1q_idx = src1q_idx_base + src11_wg_idx * blocks_per_wg + thread_id / 8u; let qs_idx = thread_id % 8u; @@ -85,7 +90,7 @@ fn main( var thread_amax = 0.0; let src11_vec4_idx = src11_wg_idx * WG_SIZE + thread_id; - let is_valid = src11_vec4_idx < num_vec4; + let is_valid = src11_vec4_idx < ne0_vec4; #ifdef USE_SUBGROUP_REDUCTION diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 463963f2ac..1bda9452dd 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -359,6 +359,7 @@ class Keys: CHUNK_SIZE = "clip.audio.chunk_size" CONV_KERNEL_SIZE = "clip.audio.conv_kernel_size" MAX_POS_EMB = "clip.audio.max_pos_emb" + FEATURE_LAYERS = "clip.audio.feature_layer" # Granite Speech Plus class Attention: HEAD_COUNT = "clip.audio.attention.head_count" diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index f707f29dc5..a06ec88b32 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -1310,6 +1310,9 @@ class GGUFWriter: def add_audio_max_pos_emb(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.MAX_POS_EMB, value) + def add_audio_feature_layers(self, layers: Sequence[int]) -> None: + self.add_array(Keys.ClipAudio.FEATURE_LAYERS, layers) + def add_audio_projector_window_size(self, value: int) -> None: self.add_uint32(Keys.ClipAudio.Projector.WINDOW_SIZE, value) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 15ae38927c..127c4634c0 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -8433,6 +8433,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {1, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {3, 2}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {3, 2}, {2, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {1, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, k, {1, 1}, {2, 1})); @@ -8449,6 +8450,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 1, 3, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 4, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 2, 1, 3})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 1, 3, 2})); test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 8, k, {2, 3}, {1, 1}, {0, 3, 2, 1})); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 30aa35e137..c38aed8cfe 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -1562,37 +1562,112 @@ static void test_msgs_oaicompat_json_conversion() { } } -static void test_split_by_role() { +static void test_msg_token_delimiters_split() { LOG_DBG("%s\n", __func__); + // Delimiters that share a leading token, distinguished by the second token, + // to exercise the per-position token matching. + const common_chat_msg_delimiters delims = { + { { COMMON_CHAT_ROLE_USER, "", { 10, 11 } }, + { COMMON_CHAT_ROLE_ASSISTANT, "", { 10, 12 } } } + }; + // Empty inputs - assert_equals(0, common_chat_split_by_role("", {}).size()); - assert_equals(0, common_chat_split_by_role("hello", {}).size()); - assert_equals(0, common_chat_split_by_role("", { { "user", "<|user|>" } }).size()); + assert_equals(0, common_chat_msg_delimiters{}.split({}).spans.size()); + assert_equals(0, common_chat_msg_delimiters{}.split({ 10, 11 }).spans.size()); + assert_equals(0, delims.split({}).spans.size()); - // Multi-role conversation, no leading/trailing content + // No delimiters match -> no spans + assert_equals(0, delims.split({ 100, 101, 102 }).spans.size()); + + // Multi-role conversation: HiHelloBye { - const std::string prompt = "<|user|>Hi<|assistant|>Hello<|user|>Bye"; - const auto splits = common_chat_split_by_role(prompt, { - { "user", "<|user|>" }, - { "assistant", "<|assistant|>" }, - }); - assert_equals(3, splits.size()); + const llama_tokens tokens = { + 10, 11, // + 100, 101, // Hi + 10, 12, // + 200, 201, 202, // Hello + 10, 11, // + 300, 301, // Bye + }; - assert_equals("user", splits[0].role); - assert_equals(0, splits[0].pos); - assert_equals(10, splits[0].len); - assert_equals("<|user|>Hi", prompt.substr(splits[0].pos, splits[0].len)); + const auto result = delims.split(tokens); + const auto & spans = result.spans; + assert_equals(3, spans.size()); - assert_equals("assistant", splits[1].role); - assert_equals(10, splits[1].pos); - assert_equals(18, splits[1].len); - assert_equals("<|assistant|>Hello", prompt.substr(splits[1].pos, splits[1].len)); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(4, spans[0].len); - assert_equals("user", splits[2].role); - assert_equals(28, splits[2].pos); - assert_equals(11, splits[2].len); - assert_equals("<|user|>Bye", prompt.substr(splits[2].pos, splits[2].len)); + assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role); + assert_equals(4, spans[1].pos); + assert_equals(5, spans[1].len); + + assert_equals(COMMON_CHAT_ROLE_USER, spans[2].role); + assert_equals(9, spans[2].pos); + assert_equals(4, spans[2].len); + + // is_user_start() is true at the token position where a user span begins + assert_equals(true, result.is_user_start(0)); + assert_equals(false, result.is_user_start(4)); // assistant span + assert_equals(true, result.is_user_start(9)); + } + + // Content before the first delimiter is not captured as a span + { + const llama_tokens tokens = { + 500, 501, // leading content (dropped) + 10, 11, // + 100, // Hi + }; + + const auto spans = delims.split(tokens).spans; + assert_equals(1, spans.size()); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(2, spans[0].pos); + assert_equals(3, spans[0].len); + } + + // Skipped regions (media chunks) are jumped over but still count as span content + { + const llama_tokens tokens = { + 10, 11, // + LLAMA_TOKEN_NULL, // media chunk (3 tokens) + LLAMA_TOKEN_NULL, + LLAMA_TOKEN_NULL, + 100, // Hi + 10, 12, // + }; + + const std::map skips = { { 2, 3 } }; + + const auto spans = delims.split(tokens, skips).spans; + assert_equals(2, spans.size()); + + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(6, spans[0].len); + + assert_equals(COMMON_CHAT_ROLE_ASSISTANT, spans[1].role); + assert_equals(6, spans[1].pos); + assert_equals(2, spans[1].len); + } + + // A delimiter sequence inside a skipped region is not matched + { + const llama_tokens tokens = { + 10, 11, // + 10, 12, // skipped region that happens to contain delimiter tokens + 100, // Hi + }; + + const std::map skips = { { 2, 2 } }; + + const auto spans = delims.split(tokens, skips).spans; + assert_equals(1, spans.size()); + assert_equals(COMMON_CHAT_ROLE_USER, spans[0].role); + assert_equals(0, spans[0].pos); + assert_equals(5, spans[0].len); } } @@ -5857,7 +5932,7 @@ int main(int argc, char ** argv) { { test_msg_diffs_compute(); test_msgs_oaicompat_json_conversion(); - test_split_by_role(); + test_msg_token_delimiters_split(); test_tools_oaicompat_json_conversion(); test_convert_responses_to_chatcmpl(); test_developer_role_to_system_workaround(); diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index e7b5301445..5b413681f0 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -42,6 +42,7 @@ #define KEY_N_HEAD "clip.%s.attention.head_count" #define KEY_N_HEAD_KV "clip.%s.attention.head_count_kv" #define KEY_LAYER_NORM_EPS "clip.%s.attention.layer_norm_epsilon" +#define KEY_FEATURE_LAYERS "clip.%s.feature_layer" // vision-specific #define KEY_VISION_PROJ_TYPE "clip.vision.projector_type" // for models with mixed modalities @@ -54,7 +55,6 @@ #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" -#define KEY_FEATURE_LAYER "clip.vision.feature_layer" #define KEY_PROJ_SCALE_FACTOR "clip.vision.projector.scale_factor" #define KEY_PROJ_SAMPLE_QUERY_SIDE "clip.vision.projector.query_side" #define KEY_PROJ_SAMPLE_WINDOW_SIDE "clip.vision.projector.window_side" diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 48796b6306..f86702eba4 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -91,7 +91,7 @@ struct clip_hparams { float eps = 1e-6; float rope_theta = 0.0; - std::vector vision_feature_layer; + std::vector feature_layers; int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) @@ -165,8 +165,8 @@ struct clip_hparams { return false; } - bool is_vision_feature_layer(int32_t layer) const { - return std::find(vision_feature_layer.begin(), vision_feature_layer.end(), layer) != vision_feature_layer.end(); + bool is_feature_layer(int32_t layer) const { + return std::find(feature_layers.begin(), feature_layers.end(), layer) != feature_layers.end(); } }; diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 7dd7023c41..7bd486030f 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -1264,12 +1264,10 @@ struct clip_model_loader { } } - // Load the vision feature layer indices if they are explicitly provided; - // if multiple vision feature layers are present, the values will be concatenated - // to form the final visual features. + // Load the vision/audio feature layer indices if they are explicitly provided // NOTE: gguf conversions should standardize the values of the vision feature layer to // be non-negative, since we use -1 to mark values as unset here. - get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer, false); + get_arr_int(string_format(KEY_FEATURE_LAYERS, prefix), hparams.feature_layers, false); // model-specific params switch (model.proj_type) { @@ -1651,6 +1649,7 @@ struct clip_model_loader { get_u32(KEY_A_PROJ_WINDOW_SIZE, hparams.audio_proj_window_size); get_u32(KEY_A_PROJ_DOWNSAMPLE_RATE, hparams.audio_proj_downsample_rate); get_u32(KEY_A_PROJ_HEAD_COUNT, hparams.audio_proj_head_count); + // NOTE: feature layers loaded above in common path } break; case PROJECTOR_TYPE_JANUS_PRO: { @@ -1663,11 +1662,11 @@ struct clip_model_loader { hparams.image_resize_algo = RESIZE_ALGO_BICUBIC_PILLOW; hparams.image_resize_pad = PAD_CEIL; - get_arr_int(KEY_FEATURE_LAYER, hparams.vision_feature_layer); + // NOTE: feature_layers loaded in common path as optional get_arr_int(KEY_PROJ_SPATIAL_OFFSETS, hparams.proj_spatial_offsets); - if (hparams.vision_feature_layer.size() != hparams.proj_spatial_offsets.size()) { - throw std::runtime_error(string_format("%s: vision_feature_layer.size() %d != proj_spatial_offsets.size() %d", - hparams.vision_feature_layer.size(), hparams.proj_spatial_offsets.size())); + if (hparams.feature_layers.size() != hparams.proj_spatial_offsets.size()) { + throw std::runtime_error(string_format("%s: feature_layers.size() %d != proj_spatial_offsets.size() %d", + hparams.feature_layers.size(), hparams.proj_spatial_offsets.size())); } get_u32(KEY_PROJ_SAMPLE_QUERY_SIDE, hparams.downsample_query_side); @@ -2740,7 +2739,7 @@ struct clip_model_loader { model.image_newline = get_tensor(TN_IMAGE_NEWLINE); // Load separate layerwise and spatial projector tensors - const auto projector_count = hparams.vision_feature_layer.size(); + const auto projector_count = hparams.feature_layers.size(); model.qf_proj_blocks.resize(projector_count); for (size_t bid = 0; bid < projector_count; ++bid) { auto & b = model.qf_proj_blocks[bid]; @@ -4388,7 +4387,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, int n_threads, const clip_image_f32 // Stage 1b only uses block 0's permutations; future stages // will upload all blocks. - for (size_t bid = 0; bid < hparams.vision_feature_layer.size(); ++bid) { + for (size_t bid = 0; bid < hparams.feature_layers.size(); ++bid) { const std::string prefix = "g4v_blk" + std::to_string(bid) + "_"; upload(prefix + "win_idx", make_win_idx(image_side, window_side)); upload(prefix + "qwin_idx", make_win_idx(new_side, query_side)); diff --git a/tools/mtmd/models/granite-speech.cpp b/tools/mtmd/models/granite-speech.cpp index 0bd4d75ac5..a158a59ce9 100644 --- a/tools/mtmd/models/granite-speech.cpp +++ b/tools/mtmd/models/granite-speech.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include + ggml_cgraph * clip_graph_granite_speech::build() { const int n_frames = img.nx(); const int context_size = hparams.audio_chunk_size; @@ -11,6 +13,10 @@ ggml_cgraph * clip_graph_granite_speech::build() { const int padded_len = num_blocks * context_size; const int remainder = n_frames % context_size; + // Calculate projector input dimension based on feature layers + const int proj_input_dim = n_embd * (hparams.feature_layers.size() + 1); + const bool use_feature_concat = !hparams.feature_layers.empty(); + ggml_tensor * attn_dists = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, context_size * context_size); ggml_set_name(attn_dists, "attn_dists"); ggml_set_input(attn_dists); @@ -31,6 +37,15 @@ ggml_cgraph * clip_graph_granite_speech::build() { cur = ggml_add(ctx0, cur, model.inp_proj_b); cb(cur, "inp_linear", -1); + // Capture layer 0 if requested (after input_linear) + ggml_tensor * concat_result = nullptr; + if (use_feature_concat) { + if (std::find(hparams.feature_layers.begin(), hparams.feature_layers.end(), 0) != hparams.feature_layers.end()) { + concat_result = cur; + cb(concat_result, "feature_layer_0", -1); + } + } + for (int il = 0; il < n_layer; il++) { const auto & layer = model.layers[il]; auto * residual = cur; @@ -168,6 +183,18 @@ ggml_cgraph * clip_graph_granite_speech::build() { NORM_TYPE_NORMAL, eps, il); cb(cur, "layer_out", il); + // Capture intermediate layer (il + 1) if requested + if (use_feature_concat) { + if (hparams.is_feature_layer(il + 1)) { + if (concat_result == nullptr) { + concat_result = cur; + } else { + concat_result = ggml_concat(ctx0, concat_result, cur, 0); + } + cb(concat_result, string_format("feature_layer_%d", il + 1).c_str(), il); + } + } + // CTC branch if (il + 1 == ctc_layer) { auto * mid = build_mm(model.ctc_out_w, cur); @@ -180,6 +207,13 @@ ggml_cgraph * clip_graph_granite_speech::build() { } } + // Append final output to concatenated features if using feature concatenation + if (use_feature_concat && concat_result != nullptr) { + concat_result = ggml_concat(ctx0, concat_result, cur, 0); + cb(concat_result, "concat_final", -1); + cur = concat_result; + } + cb(cur, "encoder_out", -1); // QFormer projector @@ -197,7 +231,7 @@ ggml_cgraph * clip_graph_granite_speech::build() { cur = ggml_pad(ctx0, cur, 0, padded_proj - n_frames, 0, 0); } - ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, n_embd, window_size, nblocks_proj); + ggml_tensor * enc_windows = ggml_reshape_3d(ctx0, cur, proj_input_dim, window_size, nblocks_proj); ggml_tensor * queries = build_norm(model.qf_proj_blocks[0].qf_proj_query, model.qf_proj_blocks[0].qf_proj_norm_w, model.qf_proj_blocks[0].qf_proj_norm_b, diff --git a/tools/mtmd/models/granite4-vision.cpp b/tools/mtmd/models/granite4-vision.cpp index 9adb6f0fdb..1b252543c0 100644 --- a/tools/mtmd/models/granite4-vision.cpp +++ b/tools/mtmd/models/granite4-vision.cpp @@ -304,14 +304,14 @@ ggml_cgraph * clip_graph_granite4_vision::build() { } // --- Stage 1b/1c: WindowQFormer blocks --- - const int projector_count = hparams.vision_feature_layer.size(); + const int projector_count = hparams.feature_layers.size(); const float qformer_eps = 1e-12f; ggml_tensor * mmproj = nullptr; for (int bid = 0; bid < projector_count; ++bid) { const auto & blk = model.qf_proj_blocks[bid]; - int vlayer = hparams.vision_feature_layer[bid]; + int vlayer = hparams.feature_layers[bid]; GGML_ASSERT(vlayer >= 0 && vlayer < n_layer); ggml_tensor * h = layer_outs[vlayer]; diff --git a/tools/mtmd/models/llava.cpp b/tools/mtmd/models/llava.cpp index 5aa3d2f0fa..47efe68bd8 100644 --- a/tools/mtmd/models/llava.cpp +++ b/tools/mtmd/models/llava.cpp @@ -21,7 +21,7 @@ ggml_cgraph * clip_graph_llava::build() { // If we set explicit vision feature layers, only go up to the deepest one // NOTE: only used by granite-vision models for now - for (const auto & feature_layer : hparams.vision_feature_layer) { + for (const auto & feature_layer : hparams.feature_layers) { if (feature_layer > deepest_feature_layer) { deepest_feature_layer = feature_layer; } @@ -59,7 +59,7 @@ ggml_cgraph * clip_graph_llava::build() { // If this is an embedding feature layer, save the output. // NOTE: 0 index here refers to the input to the encoder. - if (hparams.is_vision_feature_layer(il)) { + if (hparams.is_feature_layer(il)) { embedding_stack.push_back(cur); } @@ -134,7 +134,7 @@ ggml_cgraph * clip_graph_llava::build() { // process vision feature layers (used by granite) { // final layer is a vision feature layer - if (hparams.is_vision_feature_layer(max_feature_layer)) { + if (hparams.is_feature_layer(max_feature_layer)) { embedding_stack.push_back(inpL); } diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index e412b94c5c..ac291d359a 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -518,6 +518,14 @@ size_t server_tokens::get_common_prefix(const server_tokens & b) const { return max_idx; // all tokens are equal } +common_chat_msg_spans server_tokens::find_message_spans(const common_chat_msg_delimiters & delims) const { + std::map skips; + for (const auto & it : map_idx_to_media) { + skips[it.first] = mtmd_input_chunk_get_n_tokens(it.second.get()); + } + return delims.split(tokens, skips); +} + bool server_tokens::validate(const struct llama_context * ctx) const { const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -1104,15 +1112,7 @@ json oaicompat_chat_params_parse( llama_params["chat_parser"] = chat_params.parser; } - llama_params["message_spans"] = json::array(); - - for (const auto & span : chat_params.message_spans) { - llama_params["message_spans"].push_back({ - { "role", span.role }, - { "pos", span.pos }, - { "len", span.len }, - }); - } + llama_params["message_delimiters"] = chat_params.message_delimiters.to_json(); // Reasoning budget: pass parameters through to sampling layer { diff --git a/tools/server/server-common.h b/tools/server/server-common.h index efd31733b0..c0eaec6b02 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -218,6 +218,9 @@ public: size_t get_common_prefix(const server_tokens & b) const; + // split the tokens into message spans, skipping over media chunks + common_chat_msg_spans find_message_spans(const common_chat_msg_delimiters & delims) const; + // make sure all text tokens are within the vocab range bool validate(const struct llama_context * ctx) const; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 0a25b414ed..ca91449d26 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -3436,8 +3436,8 @@ private: has_mtmd = true; } - const int32_t n_before_user = slot.task->params.n_before_user; - const bool n_before_user_known = n_before_user > 0; + const auto & spans = slot.task->params.message_spans; + const auto last_user_pos = spans.last_user_message_pos(); // add prompt tokens for processing in the current batch while (slot.prompt.n_tokens() < slot.task->n_tokens() && batch.size() < n_batch) { @@ -3466,10 +3466,8 @@ private: slot.n_prompt_tokens_processed++; - // stop the prompt batch exactly before the latest user input, so a checkpoint - // can be created after the previous messages - if (n_before_user_known && - slot.prompt.n_tokens() == n_before_user) { + // stop the prompt batch exactly before a user message + if (spans.is_user_start(slot.prompt.n_tokens())) { break; } @@ -3498,8 +3496,13 @@ private: // the number of tokens added to the batch for the current slot const auto n_tokens_cur = batch.size() - n_tokens_prev; + const auto n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; + const bool near_prompt_end = slot.task->n_tokens() < slot.prompt.n_tokens() + n_ubatch; + const bool is_user_start = spans.is_user_start(n_tokens_start); + const bool is_last_user_message = n_tokens_start == last_user_pos; + // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; @@ -3514,8 +3517,9 @@ private: slot.init_sampler(); } else { - // skip ordinary mid-prompt checkpoints - if (!n_before_user_known && !near_prompt_end) { + // skip ordinary mid-prompt checkpoints, unless the batch starts a user + // message or we are near the end of the prompt + if (!is_user_start && !near_prompt_end) { do_checkpoint = false; } } @@ -3523,29 +3527,6 @@ private: const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx_tgt), slot.id); const auto pos_max = llama_memory_seq_pos_max(llama_get_memory(ctx_tgt), slot.id); - // checkpoints are created before the current batch is decoded, so - // their token position is the batch start rather than the prompt end - const int32_t n_tokens_start = slot.prompt.n_tokens() - n_tokens_cur; - - { - const bool is_on_user = - n_before_user_known && - n_tokens_start == n_before_user; - - const bool is_after_user = - n_before_user_known && - n_tokens_start > n_before_user; - - const bool is_allowed = - !n_before_user_known || - is_on_user || - (is_after_user && near_prompt_end); - - if (do_checkpoint && !is_allowed) { - do_checkpoint = false; - } - } - // nothing to checkpoint yet // TODO: is this check needed? if (do_checkpoint && pos_min < 0) { @@ -3555,8 +3536,8 @@ private: // do not checkpoint after mtmd chunks do_checkpoint = do_checkpoint && !has_mtmd; - // no need to create checkpoints that are too close together - do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step); + // no need to create checkpoints that are too close together, unless it's the last user message + do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || is_last_user_message || n_tokens_start > slot.prompt.checkpoints.back().n_tokens + params_base.checkpoint_min_step); SLT_DBG(slot, "main/do_checkpoint = %s, pos_min = %d, pos_max = %d\n", do_checkpoint ? "yes" : "no", pos_min, pos_max); // note: we create the checkpoint before calling llama_decode(), so the current batch is not @@ -4055,54 +4036,6 @@ void server_context::set_state_callback(server_state_callback_t callback) { }); } -// compute the number of tokens before the last user message in the prompt -static int32_t prompt_get_n_before_user( - const json & message_spans, - const std::string & prompt, - const std::vector & files, - const llama_vocab * vocab, - mtmd_context * mctx) { - int32_t result = -1; - int32_t byte_pos = -1; - - for (const auto & span : message_spans) { - const std::string role = json_value(span, "role", std::string()); - - if (role == "user") { - byte_pos = json_value(span, "pos", -1); - } - } - - if (byte_pos >= 0) { - GGML_ASSERT((size_t) byte_pos <= prompt.size()); - - const std::string prefix = prompt.substr(0, (size_t) byte_pos); - - const std::string marker = get_media_marker(); - size_t n_prefix_media = 0; - for (size_t pos = 0; (pos = prefix.find(marker, pos)) != std::string::npos; pos += marker.size()) { - n_prefix_media++; - } - - GGML_ASSERT(n_prefix_media <= files.size()); - - if (mctx != nullptr && n_prefix_media > 0) { - // TODO: this makes a copy - avoid it - std::vector prefix_files(files.begin(), files.begin() + n_prefix_media); - - result = (int32_t) process_mtmd_prompt(mctx, prefix, prefix_files).size(); - } else { - result = (int32_t) tokenize_input_prompts(vocab, nullptr, prefix, true, true)[0].size(); - } - - SRV_TRC("message_spans: last user message: byte_pos=%d, media=%zu, n_before_user=%d\n", - byte_pos, n_prefix_media, result); - } - - return result; -} - - // // server_routes // @@ -4150,6 +4083,10 @@ std::unique_ptr server_routes::handle_completions_impl( // tasks.reserve(inputs.size()); // TODO: this is inaccurate due to child tasks + // message delimiters for checkpointing + auto delimiters = common_chat_msg_delimiters_parse(json_value(data, "message_delimiters", json::array())); + delimiters.tokenize(ctx_server.vocab); + for (size_t i = 0; i < inputs.size(); i++) { server_task task = server_task(type); @@ -4163,16 +4100,7 @@ std::unique_ptr server_routes::handle_completions_impl( meta->logit_bias_eog, data); - const auto message_spans = json_value(data, "message_spans", json::array()); - if (prompt.is_string() && message_spans.is_array()) { - task.params.n_before_user = - prompt_get_n_before_user( - message_spans, - prompt.get(), - files, - ctx_server.vocab, - ctx_server.mctx); - } + task.params.message_spans = task.tokens.find_message_spans(delimiters); task.id_slot = json_value(data, "id_slot", -1); diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index cf0bc845ea..b1513c9fe6 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -218,7 +218,7 @@ void server_model_meta::update_caps() { }); params.offline = true; // params.skip_download = true; // TODO: ideally, we should validate the model here, but it takes too much time - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {}); if (params.mmproj.path.empty()) { multimodal = { false, false }; } else { @@ -1327,7 +1327,9 @@ struct server_download_state : public common_download_callback { bool run(common_params & params) { try { - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, this); + common_params_handle_models_params p; + p.callback = this; + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, p); is_ok = true; } catch (const std::exception & e) { auto model_name = params.model.get_name(); diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 299c279d7d..293bdf053a 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -62,9 +62,6 @@ struct task_params { int32_t n_cache_reuse = 0; // min chunk size to attempt reusing from the cache via KV shifting (0 = disabled) - // number of prompt tokens before the latest user message - int32_t n_before_user = -1; - int64_t t_max_prompt_ms = -1; // TODO: implement int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit @@ -92,6 +89,9 @@ struct task_params { // per-request parameters for chat parsing common_chat_parser_params chat_parser_params; + // message spans for checkpointing + common_chat_msg_spans message_spans; + // Embeddings int32_t embd_normalize = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) diff --git a/tools/server/server.cpp b/tools/server/server.cpp index b5902458c8..a101df655d 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -96,6 +96,17 @@ int llama_server(int argc, char ** argv) { } int llama_server(common_params & params, int argc, char ** argv) { + // note: router mode also accepts -hf remote-preset, so we need to check that first + if (!params.model.hf_repo.empty()) { + try { + common_params_handle_models_params handle_params; + handle_params.preset_only = true; + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, handle_params); + } catch (const std::exception & e) { + // ignored for now + } + } + // router server never loads a model and must not touch the GPU const bool is_router_server = params.model.path.empty() && params.model.hf_repo.empty(); @@ -270,7 +281,7 @@ int llama_server(common_params & params, int argc, char ** argv) { return child.run_download(params); } else if (!is_router_server) { // single-model mode (NOT spawned by router) - common_params_handle_models(params, LLAMA_EXAMPLE_SERVER); + common_params_handle_models(params, LLAMA_EXAMPLE_SERVER, {}); } // diff --git a/tools/server/tests/unit/test_router.py b/tools/server/tests/unit/test_router.py index 41e95f4c5f..94165e520e 100644 --- a/tools/server/tests/unit/test_router.py +++ b/tools/server/tests/unit/test_router.py @@ -256,6 +256,25 @@ def test_router_reload_models(): os.remove(preset_path) +def test_router_remote_preset(): + global server + server.model_hf_repo = "ggml-org/test-preset-ci" + server.model_hf_file = None + server.offline = False + server.start() + + # Should see preset models in GET /models + res = server.make_request("GET", "/models") + assert res.status_code == 200 + ids = {item["id"] for item in res.body.get("data", [])} + assert "tinygemma3-preset" in ids + assert "stories260K-test" in ids + + # Should be able to load a preset model + model_id = "tinygemma3-preset" + _load_model_and_wait(model_id) + + MODEL_DOWNLOAD_ID = "ggml-org/test-model-router-download:F16" MODEL_DOWNLOAD_TIMEOUT = 30