diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index ed4a6b13bb..6a2eb8c824 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -98,6 +98,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } #endif // INIT_SRC0_SHMEM_Q1_0 +// legacy-quants #if defined(INIT_SRC0_SHMEM_Q4_0) || defined(INIT_SRC0_SHMEM_Q4_1) || defined(INIT_SRC0_SHMEM_Q5_0) || defined(INIT_SRC0_SHMEM_Q5_1) || defined(INIT_SRC0_SHMEM_Q8_0) || defined(INIT_SRC0_SHMEM_Q8_1) || defined(INIT_SRC0_SHMEM_MXFP4) const BLOCK_SIZE = 32u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. @@ -124,7 +125,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 if (global_m < params.m && global_block_k < params.k / BLOCK_SIZE) { let src0_idx = batch_offset + global_m * params.stride_01 + global_block_k; -#ifdef INIT_SRC0_SHMEM_Q4_0 +#if defined(INIT_SRC0_SHMEM_Q4_0) let block_byte_base = src0_idx * 18u; // BLOCK_SIZE_BYTES = 18u; let d = load_f16_at_src0(block_byte_base); @@ -134,7 +135,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let q_packed = load_u32_at_src0(q_byte_offset); dequant_q4_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } -#elif INIT_SRC0_SHMEM_Q4_1 +#endif // INIT_SRC0_SHMEM_Q4_0 + +#if defined(INIT_SRC0_SHMEM_Q4_1) let block_byte_base = src0_idx * 20u; // BLOCK_SIZE_BYTES = 20u; let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); let d = f16(dm[0]); @@ -153,7 +156,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } -#elif INIT_SRC0_SHMEM_Q5_0 +#endif // INIT_SRC0_SHMEM_Q4_1 + +#if defined(INIT_SRC0_SHMEM_Q5_0) let block_byte_base = src0_idx * 22u; // BLOCK_SIZE_BYTES = 22u; let d = load_f16_at_src0(block_byte_base); @@ -176,7 +181,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } -#elif INIT_SRC0_SHMEM_Q5_1 +#endif // INIT_SRC0_SHMEM_Q5_0 + +#if defined(INIT_SRC0_SHMEM_Q5_1) let block_byte_base = src0_idx * 24u; // BLOCK_SIZE_BYTES = 24u; let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); @@ -201,7 +208,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = q_hi; } } -#elif INIT_SRC0_SHMEM_Q8_0 +#endif // INIT_SRC0_SHMEM_Q5_1 + +#if defined(INIT_SRC0_SHMEM_Q8_0) let block_byte_base = src0_idx * 34u; // BLOCK_SIZE_BYTES = 34u; let d = load_f16_at_src0(block_byte_base); @@ -211,7 +220,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let q_packed = load_u32_at_src0(q_byte_offset); dequant_q8_0_packed_to_shmem(q_packed, d, shmem_idx + j * BYTES_PER_INNER_LOOP); } -#elif INIT_SRC0_SHMEM_Q8_1 +#endif // INIT_SRC0_SHMEM_Q8_0 + +#if defined(INIT_SRC0_SHMEM_Q8_1) let block_byte_base = src0_idx * 36u; // BLOCK_SIZE_BYTES = 36u; let dm = unpack2x16float(load_u32_at_src0_aligned(block_byte_base)); let d = f16(dm[0]); @@ -227,7 +238,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k] = q_val; } } -#elif INIT_SRC0_SHMEM_MXFP4 +#endif // INIT_SRC0_SHMEM_Q8_1 + +#if defined(INIT_SRC0_SHMEM_MXFP4) let block_byte_base = src0_idx * 17u; let eu8 = get_byte(load_u32_at_src0_aligned(block_byte_base), block_byte_base & 3u); let e = ldexp(1.0, i32(eu8) - 128); @@ -244,11 +257,11 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 shmem[shmem_idx + j * BYTES_PER_INNER_LOOP + k + 16u] = f16(q_hi); } } -#endif +#endif // INIT_SRC0_SHMEM_MXFP4 } } } -#endif +#endif // legacy-quants // k-quants #if defined(INIT_SRC0_SHMEM_Q2_K) || defined(INIT_SRC0_SHMEM_Q3_K) || defined(INIT_SRC0_SHMEM_Q4_K) || defined(INIT_SRC0_SHMEM_Q5_K) || defined(INIT_SRC0_SHMEM_Q6_K) @@ -284,7 +297,7 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let src0_idx = batch_offset + global_m * params.stride_01 + block_k; -#ifdef INIT_SRC0_SHMEM_Q2_K +#if defined(INIT_SRC0_SHMEM_Q2_K) let block_byte_base = src0_idx * 84u; // BLOCK_SIZE_BYTES = 84u; let scales_byte_base = block_byte_base; let qs_byte_base = block_byte_base + 16u; @@ -314,7 +327,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(scale >> 4u); store_shmem_kquants(qs_vec4 * dl - ml, elem_idx); -#elif INIT_SRC0_SHMEM_Q3_K +#endif // INIT_SRC0_SHMEM_Q2_K + +#if defined(INIT_SRC0_SHMEM_Q3_K) let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u; let hmask_byte_base = block_byte_base + 0u; let qs_byte_base = block_byte_base + 32u; @@ -355,7 +370,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let dl = d_all * (f16((scale_hi2 << 4u) | scale_low4) - 32.0); store_shmem_kquants(dl * q_vec4, elem_idx); -#elif INIT_SRC0_SHMEM_Q4_K +#endif // INIT_SRC0_SHMEM_Q3_K + +#if defined(INIT_SRC0_SHMEM_Q4_K) let block_byte_base = src0_idx * 144u; // BLOCK_SIZE_BYTES = 144u; let dm_byte_base = block_byte_base + 0u; let scale_byte_base = block_byte_base + 4u; @@ -399,7 +416,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); store_shmem_kquants(dl * qs_vec4 - vec4(ml, ml, ml, ml), elem_idx); -#elif INIT_SRC0_SHMEM_Q5_K +#endif // INIT_SRC0_SHMEM_Q4_K + +#if defined(INIT_SRC0_SHMEM_Q5_K) let block_byte_base = src0_idx * 176u; // BLOCK_SIZE_BYTES = 176u; let dm_byte_base = block_byte_base + 0u; let scale_byte_base = block_byte_base + 4u; @@ -456,7 +475,9 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let ml = dmin * f16(mn); store_shmem_kquants((qh_vec4 + qs_lo4_vec4) * dl - vec4(ml, ml, ml, ml), elem_idx); -#elif INIT_SRC0_SHMEM_Q6_K +#endif // INIT_SRC0_SHMEM_Q5_K + +#if defined(INIT_SRC0_SHMEM_Q6_K) let block_byte_base = src0_idx * 210u; // BLOCK_SIZE_BYTES = 210u; let ql_byte_base = block_byte_base; let qh_byte_base = block_byte_base + 128u; @@ -497,17 +518,18 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 let scale = get_byte_i32(scale_word, scale_byte & 3u); store_shmem_kquants(d * q_vec4 * f16(scale), elem_idx); -#endif +#endif // INIT_SRC0_SHMEM_Q6_K } } #endif // k-quants -#ifdef INIT_SRC0_SHMEM_IQ4_NL +#if defined(INIT_SRC0_SHMEM_IQ4_NL) const BLOCK_SIZE = 32u; const BLOCK_SIZE_BYTES = 18u; +const NQ = 4u; fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += NQ * TOTAL_WORKGROUP_SIZE) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_m = offset_m + tile_m; @@ -519,408 +541,464 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 4 == 0; + + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_at_src0(block_byte_base); + let d_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 2u; - let pos = k_in_block % 16u; - let nib_shift = (k_in_block / 16u) * 4u; - let q_packed = load_u32_at_src0(block_byte_base + 2u + (pos / 4u) * 4u); - let nib = (get_byte(q_packed, pos % 4u) >> nib_shift) & 0xFu; + let d = load_f16_at_src0(d_byte_base); - shmem[elem_idx] = d * f16(kvalues_iq4nl[nib]); + let id_qtr = (k_in_block % 16u) / 4u; + let shift_phase = k_in_block / 16u; + + let qs_u32 = load_u32_at_src0(qs_byte_base + 4u * id_qtr); + + shmem[elem_idx + 0u] = d * f16(kvalues_iq4nl[(qs_u32 >> ( 0u + 4u * shift_phase)) & 0xFu]); + shmem[elem_idx + 1u] = d * f16(kvalues_iq4nl[(qs_u32 >> ( 8u + 4u * shift_phase)) & 0xFu]); + shmem[elem_idx + 2u] = d * f16(kvalues_iq4nl[(qs_u32 >> (16u + 4u * shift_phase)) & 0xFu]); + shmem[elem_idx + 3u] = d * f16(kvalues_iq4nl[(qs_u32 >> (24u + 4u * shift_phase)) & 0xFu]); } } #endif // INIT_SRC0_SHMEM_IQ4_NL -#ifdef INIT_SRC0_SHMEM_IQ4_XS +// i-quants (super block size: 256) +#if defined(INIT_SRC0_SHMEM_IQ4_XS) || defined(INIT_SRC0_SHMEM_IQ1_S) || defined(INIT_SRC0_SHMEM_IQ1_M) || defined(INIT_SRC0_SHMEM_IQ2_XXS) \ +|| defined(INIT_SRC0_SHMEM_IQ2_XS) || defined(INIT_SRC0_SHMEM_IQ2_S) || defined(INIT_SRC0_SHMEM_IQ3_XXS) || defined(INIT_SRC0_SHMEM_IQ3_S) const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 136u; +const NQ = 16u; + +fn store_shmem_iquants(val: vec4, idx: u32) { + shmem[idx] = val.x; + shmem[idx + 1] = val.y; + shmem[idx + 2] = val.z; + shmem[idx + 3] = val.w; +} + +fn load_byte_at_src0_aligned(byte_offset: u32) -> u32 { + return get_byte(load_u32_at_src0_aligned(byte_offset), byte_offset % 4u); +} + +#if defined(INIT_SRC0_SHMEM_IQ1_M) || defined(INIT_SRC0_SHMEM_IQ1_S) +fn create_iq_gw4(dl: f32, gw: u32, shift_base: u32, delta: f32) -> vec4 { + return vec4( + f16(dl * (f32((bitcast(((gw >> (shift_base + 0u)) & 3u) << 30u) >> 30u)) + delta)), + f16(dl * (f32((bitcast(((gw >> (shift_base + 2u)) & 3u) << 30u) >> 30u)) + delta)), + f16(dl * (f32((bitcast(((gw >> (shift_base + 4u)) & 3u) << 30u) >> 30u)) + delta)), + f16(dl * (f32((bitcast(((gw >> (shift_base + 6u)) & 3u) << 30u) >> 30u)) + delta)), + ); +} +#endif + +#if defined(INIT_SRC0_SHMEM_IQ4_XS) +fn create_iq_gw4(dl: f16, qs_u32: u32, shift_phase: u32) -> vec4 { + return vec4( + dl * f16(kvalues_iq4nl[(qs_u32 >> (4 * shift_phase + 0u)) & 0xFu]), + dl * f16(kvalues_iq4nl[(qs_u32 >> (4 * shift_phase + 8u)) & 0xFu]), + dl * f16(kvalues_iq4nl[(qs_u32 >> (4 * shift_phase + 16u)) & 0xFu]), + dl * f16(kvalues_iq4nl[(qs_u32 >> (4 * shift_phase + 24u)) & 0xFu]), + ); +} +#endif + +#if defined(INIT_SRC0_SHMEM_IQ2_XXS) +fn create_iq_gw4(ig: u32, grid_phase: u32) -> vec4 { + return vec4( + f32(get_byte(iq2xxs_grid[(ig + grid_phase + 0u) / 4u], (ig + grid_phase + 0u) % 4u)), + f32(get_byte(iq2xxs_grid[(ig + grid_phase + 1u) / 4u], (ig + grid_phase + 1u) % 4u)), + f32(get_byte(iq2xxs_grid[(ig + grid_phase + 2u) / 4u], (ig + grid_phase + 2u) % 4u)), + f32(get_byte(iq2xxs_grid[(ig + grid_phase + 3u) / 4u], (ig + grid_phase + 3u) % 4u)), + ); +} +#endif + +#if defined(INIT_SRC0_SHMEM_IQ2_XS) +fn create_iq_gw4(ig: u32, grid_phase: u32) -> vec4 { + return vec4( + f32(get_byte(iq2xs_grid[(ig + grid_phase + 0u) / 4u], (ig + grid_phase + 0u) % 4u)), + f32(get_byte(iq2xs_grid[(ig + grid_phase + 1u) / 4u], (ig + grid_phase + 1u) % 4u)), + f32(get_byte(iq2xs_grid[(ig + grid_phase + 2u) / 4u], (ig + grid_phase + 2u) % 4u)), + f32(get_byte(iq2xs_grid[(ig + grid_phase + 3u) / 4u], (ig + grid_phase + 3u) % 4u)), + ); +} +#endif + +#if defined(INIT_SRC0_SHMEM_IQ2_S) +fn create_iq_gw4(ig: u32, grid_phase: u32) -> vec4 { + return vec4( + f32(get_byte(iq2s_grid[(ig + grid_phase + 0u) / 4u], (ig + grid_phase + 0u) % 4u)), + f32(get_byte(iq2s_grid[(ig + grid_phase + 1u) / 4u], (ig + grid_phase + 1u) % 4u)), + f32(get_byte(iq2s_grid[(ig + grid_phase + 2u) / 4u], (ig + grid_phase + 2u) % 4u)), + f32(get_byte(iq2s_grid[(ig + grid_phase + 3u) / 4u], (ig + grid_phase + 3u) % 4u)), + ); +} +#endif + +#if defined(INIT_SRC0_SHMEM_IQ3_XXS) +fn create_iq_gw4(ig: u32) -> vec4 { + return vec4( + f32(get_byte(iq3xxs_grid[ig], 0)), + f32(get_byte(iq3xxs_grid[ig], 1)), + f32(get_byte(iq3xxs_grid[ig], 2)), + f32(get_byte(iq3xxs_grid[ig], 3)), + ); +} +#endif + +#if defined(INIT_SRC0_SHMEM_IQ3_S) +fn create_iq_gw4(ig: u32) -> vec4 { + return vec4( + f32(get_byte(iq3s_grid[ig], 0)), + f32(get_byte(iq3s_grid[ig], 1)), + f32(get_byte(iq3s_grid[ig], 2)), + f32(get_byte(iq3s_grid[ig], 3)), + ); +} +#endif + +#if defined(INIT_SRC0_SHMEM_IQ2_XXS) || defined(INIT_SRC0_SHMEM_IQ2_XS) || defined(INIT_SRC0_SHMEM_IQ2_S) \ +|| defined(INIT_SRC0_SHMEM_IQ3_XXS) || defined(INIT_SRC0_SHMEM_IQ3_S) +fn create_iq2_m4(signs: u32, mask_phase: u32) -> vec4 { + return vec4( + select(1.0, -1.0, (get_byte(kmask_iq2xs[mask_phase], 0) & signs) != 0u), + select(1.0, -1.0, (get_byte(kmask_iq2xs[mask_phase], 1) & signs) != 0u), + select(1.0, -1.0, (get_byte(kmask_iq2xs[mask_phase], 2) & signs) != 0u), + select(1.0, -1.0, (get_byte(kmask_iq2xs[mask_phase], 3) & signs) != 0u), + ); +} +#endif fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { + for (var elem_idx = thread_id * NQ; elem_idx < TILE_SRC0_SHMEM; elem_idx += NQ * TOTAL_WORKGROUP_SIZE) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_m = offset_m + tile_m; let global_k = k_outer + tile_k; if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); + let zero_vec4 = vec4(f16(0.0), f16(0.0), f16(0.0), f16(0.0)); + store_shmem_iquants(zero_vec4, elem_idx + 0u); + store_shmem_iquants(zero_vec4, elem_idx + 4u); + store_shmem_iquants(zero_vec4, elem_idx + 8u); + store_shmem_iquants(zero_vec4, elem_idx + 12u); continue; } let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let k_in_block = global_k % BLOCK_SIZE; // k_in_block % 16 == 0; - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; + let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let d_scales_h = load_u32_at_src0(block_byte_base); +#if defined(INIT_SRC0_SHMEM_IQ4_XS) + let block_byte_base = src0_idx * 136u; // BLOCK_SIZE_BYTES = 136u; + let d_byte_base = block_byte_base + 0u; + let scales_l_byte_base = block_byte_base + 4u; + let qs_byte_base = block_byte_base + 8u; + + let d_scales_h = load_u32_at_src0_aligned(d_byte_base); let d = bitcast>(d_scales_h).x; let scales_h = d_scales_h >> 16u; - let ib = k_in_block / 32u; - let pos = k_in_block % 32u; + let sub_block = k_in_block / 32u; + let phase = (k_in_block / NQ) % 2u; - let scales_l_word = load_u32_at_src0(block_byte_base + 4u); - let ls_lo = (get_byte(scales_l_word, ib / 2u) >> ((ib & 1u) * 4u)) & 0xFu; - let ls_hi = ((scales_h >> (2u * ib)) & 3u) << 4u; - let dl = d * f16(i32(ls_lo | ls_hi) - 32); + let scales_l_u32 = load_u32_at_src0_aligned(scales_l_byte_base); + let ls_lo = (get_byte(scales_l_u32, sub_block / 2u) >> (4u * (sub_block % 2u))) & 0xFu; + let ls_hi = ((scales_h >> (2u * sub_block)) & 3u) << 4u; + let dl = d * f16(i32(ls_lo | ls_hi) - 32); - let iqs = ib * 16u + (pos % 16u); - let nib_shift = (pos / 16u) * 4u; - let q_packed = load_u32_at_src0(block_byte_base + 8u + (iqs / 4u) * 4u); - let nib = (get_byte(q_packed, iqs % 4u) >> nib_shift) & 0xFu; + let qs_0_3_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 0u); + let qs_4_7_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 4u); + let qs_8_11_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 8u); + let qs_12_15_u32 = load_u32_at_src0_aligned(qs_byte_base + 16u * sub_block + 12u); - shmem[elem_idx] = dl * f16(kvalues_iq4nl[nib]); - } -} + store_shmem_iquants(create_iq_gw4(dl, qs_0_3_u32, phase), elem_idx + 0u); + store_shmem_iquants(create_iq_gw4(dl, qs_4_7_u32, phase), elem_idx + 4u); + store_shmem_iquants(create_iq_gw4(dl, qs_8_11_u32, phase), elem_idx + 8u); + store_shmem_iquants(create_iq_gw4(dl, qs_12_15_u32, phase), elem_idx + 12u); #endif // INIT_SRC0_SHMEM_IQ4_XS -#ifdef INIT_SRC0_SHMEM_IQ1_S -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 50u; +#if defined(INIT_SRC0_SHMEM_IQ1_S) + let block_byte_base = src0_idx * 50u; // BLOCK_SIZE_BYTES = 50u; + let d_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 2u; + let qh_byte_base = block_byte_base + 34u; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; + let d = load_f16_as_f32_at_src0(d_byte_base); - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } + let sub_block = k_in_block / 32u; + let phase = (k_in_block / NQ) % 2u; - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let qh_u16 = load_u32_at_src0(qh_byte_base + sub_block * 2u) & 0xFFFFu; + let qs_u16 = load_u32_at_src0(qs_byte_base + sub_block * 4u + phase * 2u) & 0xFFFFu; - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_as_f32_at_src0(block_byte_base); + let dl = d * (2.0 * f32((qh_u16 >> 12u) & 7u) + 1.0); + let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_u16 & 0x8000u) != 0u); - let ib = k_in_block / 32u; - let pos = k_in_block % 32u; - let l = pos / 8u; - let j = pos % 8u; + let gp0_grid_id = ((qs_u16 & 0xFFu) | (((qh_u16 >> (phase * 6u)) & 7u) << 8u)) * 8u; + let gp1_grid_id = (((qs_u16 >> 8) & 0xFFu) | (((qh_u16 >> (phase * 6u + 3u)) & 7u) << 8u)) * 8u; - let qh = load_u32_at_src0(block_byte_base + 34u + ib * 2u) & 0xFFFFu; - let dl = d * (2.0 * f32((qh >> 12u) & 7u) + 1.0); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh & 0x8000u) != 0u); + let gp0_gw = iq1_grid[(gp0_grid_id) / 16u]; + let gp1_gw = iq1_grid[(gp1_grid_id) / 16u]; - let qs_w = load_u32_at_src0(block_byte_base + 2u + ib * 4u); - let ig = (get_byte(qs_w, l) | (((qh >> (3u * l)) & 7u) << 8u)) * 8u; + let gp0_shift_base = (gp0_grid_id % 16u) * 2u; + let gp1_shift_base = (gp1_grid_id % 16u) * 2u; - let gw = iq1_grid[(ig + j) / 16u]; - let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; - let gs = bitcast(g << 30u) >> 30u; - - shmem[elem_idx] = f16(dl * (f32(gs) + delta)); - } -} + store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 0u, delta), elem_idx + 0u); + store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 8u, delta), elem_idx + 4u); + store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 0u, delta), elem_idx + 8u); + store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 8u, delta), elem_idx + 12u); #endif // INIT_SRC0_SHMEM_IQ1_S -#ifdef INIT_SRC0_SHMEM_IQ1_M -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 56u; +#if defined(INIT_SRC0_SHMEM_IQ1_M) + let block_byte_base = src0_idx * 56u; // BLOCK_SIZE_BYTES = 56u; + let qs_byte_base = block_byte_base + 0u; + let qh_byte_base = block_byte_base + 32u; + let scales_byte_base = block_byte_base + 48u; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; - - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } - - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - - let scales0 = load_u32_at_src0(block_byte_base + 48u); - let scales1 = load_u32_at_src0(block_byte_base + 52u); + let scales0 = load_u32_at_src0_aligned(scales_byte_base); + let scales1 = load_u32_at_src0_aligned(scales_byte_base + 4u); let scale_packed = ((scales0 >> 12u) & 0xFu) | ((scales0 >> 24u) & 0x00F0u) | ((scales1 >> 4u) & 0x0F00u) | ((scales1 >> 16u) & 0xF000u); let d = f32(bitcast>(scale_packed).x); - let ib = k_in_block / 32u; - let pos = k_in_block % 32u; - let l = pos / 8u; - let j = pos % 8u; + let sub_block = k_in_block / 32u; + let phase = (k_in_block / NQ) % 2u; - let scales = select(scales0, scales1, ib >= 4u); - let sw = (scales >> (16u * ((ib / 2u) % 2u))) & 0xFFFFu; - let s_pair = (sw >> (6u * (ib % 2u) + 3u * (l / 2u))) & 0x7u; - let dl = d * f32(2u * s_pair + 1u); + let scale_u32 = select(scales0, scales1, sub_block >= 4u); + let scale_u3 = (scale_u32 >> (16u * ((sub_block / 2u) % 2u) + 6u * (sub_block % 2u) + 3u * phase)) & 0x7u; + let dl = d * f32(2u * scale_u3 + 1u); - let qh_word = load_u32_at_src0(block_byte_base + 32u + (ib / 2u) * 4u); - let qh = qh_word >> (16u * (ib % 2u)); - let qh_nib = (qh >> (4u * l)) & 0xFu; + let qh_u8 = (load_u32_at_src0_aligned(qh_byte_base + 4u * (sub_block / 2u)) >> (16u * (sub_block % 2u) + 8u * phase)) & 0xFFu; + let qs_u16 = (load_u32_at_src0_aligned(qs_byte_base + 4u * sub_block) >> (16u * phase)) & 0xFFFFu; - let qs_w = load_u32_at_src0(block_byte_base + ib * 4u); - let idx = get_byte(qs_w, l) | ((qh_nib & 7u) << 8u); - let delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_nib & 0x8u) != 0u); + let gp0_grid_id = ((qs_u16 & 0xFFu) | ((qh_u8 & 7u) << 8u)) * 8u; + let gp0_delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_u8 & 0x8u) != 0u); - let ig = idx * 8u; - let gw = iq1_grid[(ig + j) / 16u]; - let g = (gw >> (((ig + j) % 16u) * 2u)) & 3u; - let gs = bitcast(g << 30u) >> 30u; + let gp1_grid_id = (((qs_u16 >> 8u) & 0xFFu) | (((qh_u8 >> 4u) & 7u) << 8u)) * 8u; + let gp1_delta = select(IQ1_DELTA, -IQ1_DELTA, (qh_u8 & 0x80u) != 0u); - shmem[elem_idx] = f16(dl * (f32(gs) + delta)); - } -} + let gp0_gw = iq1_grid[(gp0_grid_id) / 16u]; + let gp1_gw = iq1_grid[(gp1_grid_id) / 16u]; + + let gp0_shift_base = (gp0_grid_id % 16u) * 2u; + let gp1_shift_base = (gp1_grid_id % 16u) * 2u; + + store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 0u, gp0_delta), elem_idx + 0u); + store_shmem_iquants(create_iq_gw4(dl, gp0_gw, gp0_shift_base + 8u, gp0_delta), elem_idx + 4u); + store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 0u, gp1_delta), elem_idx + 8u); + store_shmem_iquants(create_iq_gw4(dl, gp1_gw, gp1_shift_base + 8u, gp1_delta), elem_idx + 12u); #endif // INIT_SRC0_SHMEM_IQ1_M -#ifdef INIT_SRC0_SHMEM_IQ2_XXS -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 66u; +#if defined(INIT_SRC0_SHMEM_IQ2_XXS) + let block_byte_base = src0_idx * 66u; // BLOCK_SIZE_BYTES = 66u; + let d_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 2u; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; + let d = load_f16_as_f32_at_src0(d_byte_base); - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } + let sub_block = k_in_block / 32u; + let phase = (k_in_block / NQ) % 2u; - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; - - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_as_f32_at_src0(block_byte_base); - - let entry_idx = k_in_block / 8u; - let j = k_in_block % 8u; - - let ib = entry_idx & ~3u; - let l = entry_idx & 3u; - - let aux0 = load_u32_at_src0(block_byte_base + 2u + ib * 2u); - let aux1 = load_u32_at_src0(block_byte_base + 2u + (ib + 2u) * 2u); + let aux0 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 0u); + let aux1 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u); let db = d * (0.5 + f32(aux1 >> 28u)) * 0.25; - let ig = get_byte(aux0, l) * 8u; - let is = (aux1 >> (7u * l)) & 127u; - let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + let gp0_ig = get_byte(aux0, 2u * phase + 0u) * 8u; + let gp1_ig = get_byte(aux0, 2u * phase + 1u) * 8u; - let g = get_byte(iq2xxs_grid[(ig + j) / 4u], (ig + j) % 4u); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + let gp0_is = (aux1 >> (14u * phase + 0u)) & 127u; + let gp1_is = (aux1 >> (14u * phase + 7u)) & 127u; - shmem[elem_idx] = f16(db * f32(g) * m); - } -} + let gp0_signs = get_byte(ksigns_iq2xs[gp0_is / 4u], gp0_is % 4u); + let gp1_signs = get_byte(ksigns_iq2xs[gp1_is / 4u], gp1_is % 4u); + + let m_0_3_val4 = create_iq2_m4(gp0_signs, 0); + let m_4_7_val4 = create_iq2_m4(gp0_signs, 1); + let m_8_11_val4 = create_iq2_m4(gp1_signs, 0); + let m_12_15_val4 = create_iq2_m4(gp1_signs, 1); + + let gw_0_3_val4 = create_iq_gw4(gp0_ig, 0); + let gw_4_7_val4 = create_iq_gw4(gp0_ig, 4); + let gw_8_11_val4 = create_iq_gw4(gp1_ig, 0); + let gw_12_15_val4 = create_iq_gw4(gp1_ig, 4); + + store_shmem_iquants(vec4(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u); + store_shmem_iquants(vec4(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u); + store_shmem_iquants(vec4(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u); + store_shmem_iquants(vec4(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u); #endif // INIT_SRC0_SHMEM_IQ2_XXS -#ifdef INIT_SRC0_SHMEM_IQ2_XS -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 74u; +#if defined(INIT_SRC0_SHMEM_IQ2_XS) + let block_byte_base = src0_idx * 74u; // BLOCK_SIZE_BYTES = 74u; + let d_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 2u; + let scales_byte_base = block_byte_base + 66u; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; + let d = load_f16_as_f32_at_src0(d_byte_base); - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } + let sub_block = k_in_block / 32u; + let phase = (k_in_block / NQ) % 2u; - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let scale = (load_byte_at_src0_aligned(scales_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu; + let db = d * (0.5 + f32(scale)) * 0.25; - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_as_f32_at_src0(block_byte_base); + let qs_u32 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u * phase); - let entry_idx = k_in_block / 8u; - let j = k_in_block % 8u; + let gp0_ig = (qs_u32 & 0x1FFu) * 8u; + let gp1_ig = ((qs_u32 >> 16u) & 0x1FFu) * 8u; - let ib = entry_idx & ~3u; - let l = entry_idx & 3u; + let gp0_is = (qs_u32 >> 9u) & 0x7Fu; + let gp1_is = (qs_u32 >> 25u) & 0x7Fu; - let scales_word = load_u32_at_src0(block_byte_base + 66u + (ib / 16u) * 4u); - let s = get_byte(scales_word, (ib % 16u) / 4u); - let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); - let dl = d * (0.5 + f32(s_nib)) * 0.25; + let gp0_signs = get_byte(ksigns_iq2xs[gp0_is / 4u], gp0_is % 4u); + let gp1_signs = get_byte(ksigns_iq2xs[gp1_is / 4u], gp1_is % 4u); - let qs_word = load_u32_at_src0(block_byte_base + 2u + (ib + l) * 2u); - let qs_val = qs_word & 0xFFFFu; - let ig = (qs_val & 511u) * 8u; - let is = qs_val >> 9u; - let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + let m_0_3_val4 = create_iq2_m4(gp0_signs, 0); + let m_4_7_val4 = create_iq2_m4(gp0_signs, 1); + let m_8_11_val4 = create_iq2_m4(gp1_signs, 0); + let m_12_15_val4 = create_iq2_m4(gp1_signs, 1); - let g = get_byte(iq2xs_grid[(ig + j) / 4u], (ig + j) % 4u); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); + let gw_0_3_val4 = create_iq_gw4(gp0_ig, 0); + let gw_4_7_val4 = create_iq_gw4(gp0_ig, 4); + let gw_8_11_val4 = create_iq_gw4(gp1_ig, 0); + let gw_12_15_val4 = create_iq_gw4(gp1_ig, 4); - shmem[elem_idx] = f16(dl * f32(g) * m); - } -} + store_shmem_iquants(vec4(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u); + store_shmem_iquants(vec4(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u); + store_shmem_iquants(vec4(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u); + store_shmem_iquants(vec4(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u); #endif // INIT_SRC0_SHMEM_IQ2_XS -#ifdef INIT_SRC0_SHMEM_IQ2_S -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 82u; +#if defined(INIT_SRC0_SHMEM_IQ2_S) + let block_byte_base = src0_idx * 82u; // BLOCK_SIZE_BYTES = 82u; + let d_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 2u; + let qh_byte_base = block_byte_base + 66u; + let scales_byte_base = block_byte_base + 74u; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; + let d = load_f16_as_f32_at_src0(d_byte_base); - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } + let sub_block = k_in_block / 32u; + let phase = (k_in_block / NQ) % 2u; - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let scale = (load_byte_at_src0_aligned(scales_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu; + let db = d * (0.5 + f32(scale)) * 0.25; - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_as_f32_at_src0(block_byte_base); + let qs_u16 = load_u32_at_src0(qs_byte_base + 4u * sub_block + 2u * phase) & 0xFFFFu; + let signs_u16 = load_u32_at_src0(qs_byte_base + 32u + 4u * sub_block + 2u * phase) & 0xFFFFu; + let qh_u4 = (load_byte_at_src0_aligned(qh_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu; - let ib = k_in_block / 32u; - let l = (k_in_block % 32u) / 8u; - let j = k_in_block % 8u; + let gp0_ig = ((qs_u16 & 0xFFu) | ((qh_u4 & 0x3u) << 8u)) * 8u; + let gp1_ig = (((qs_u16 >> 8u) & 0xFFu) | ((qh_u4 & 0xCu) << 6u)) * 8u; - let scales_word = load_u32_at_src0(block_byte_base + 74u + (ib / 4u) * 4u); - let s = get_byte(scales_word, ib % 4u); - let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, (l / 2u) != 0u); - let dl = d * (0.5 + f32(s_nib)) * 0.25; + let gp0_signs = get_byte(signs_u16, 0); + let gp1_signs = get_byte(signs_u16, 1); - let qs_word = load_u32_at_src0(block_byte_base + 2u + ib * 4u); - let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 4u) * 4u); - let qh_b = (get_byte(qh_word, ib % 4u) << (8u - 2u * l)) & 0x300u; - let ig = (get_byte(qs_word, l) | qh_b) * 8u; + let m_0_3_val4 = create_iq2_m4(gp0_signs, 0); + let m_4_7_val4 = create_iq2_m4(gp0_signs, 1); + let m_8_11_val4 = create_iq2_m4(gp1_signs, 0); + let m_12_15_val4 = create_iq2_m4(gp1_signs, 1); - let signs_word = load_u32_at_src0(block_byte_base + 34u + ib * 4u); - let signs = get_byte(signs_word, l); + let gw_0_3_val4 = create_iq_gw4(gp0_ig, 0); + let gw_4_7_val4 = create_iq_gw4(gp0_ig, 4); + let gw_8_11_val4 = create_iq_gw4(gp1_ig, 0); + let gw_12_15_val4 = create_iq_gw4(gp1_ig, 4); - let g = get_byte(iq2s_grid[(ig + j) / 4u], (ig + j) % 4u); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[j / 4u], j % 4u) & signs) != 0u); - - shmem[elem_idx] = f16(dl * f32(g) * m); - } -} + store_shmem_iquants(vec4(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u); + store_shmem_iquants(vec4(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u); + store_shmem_iquants(vec4(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u); + store_shmem_iquants(vec4(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u); #endif // INIT_SRC0_SHMEM_IQ2_S -#ifdef INIT_SRC0_SHMEM_IQ3_XXS -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 98u; +#if defined(INIT_SRC0_SHMEM_IQ3_XXS) + let block_byte_base = src0_idx * 98u; // BLOCK_SIZE_BYTES = 98u; + let d_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 2u; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; + let d = load_f16_as_f32_at_src0(d_byte_base); - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } + let sub_block = k_in_block / 32u; + let phase = (k_in_block / NQ) % 2u; - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let qs_u32 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u * phase); + let sign_u32 = load_u32_at_src0(qs_byte_base + 64u + 4u * sub_block); + let db = d * (0.5 + f32(sign_u32 >> 28u)) * 0.5; - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_as_f32_at_src0(block_byte_base); + let ig_0_3 = get_byte(qs_u32, 0); + let ig_4_7 = get_byte(qs_u32, 1); + let ig_8_11 = get_byte(qs_u32, 2); + let ig_12_15 = get_byte(qs_u32, 3); - let ib_pair = k_in_block / 32u; - let in_pair = k_in_block % 32u; - let l = in_pair / 8u; - let in_l = in_pair % 8u; - let k2 = in_l / 4u; - let j = in_l % 4u; + let gp0_is = (sign_u32 >> (14u * phase + 0u)) & 0x7Fu; + let gp1_is = (sign_u32 >> (14u * phase + 7u)) & 0x7Fu; - let ib = ib_pair * 2u; - let sc_sign_off = block_byte_base + 2u + (ib + 32u) * 2u; - let sc_sign = load_u32_at_src0(sc_sign_off); - let db = d * (0.5 + f32(sc_sign >> 28u)) * 0.5; - let is = (sc_sign >> (7u * l)) & 127u; - let signs = get_byte(ksigns_iq2xs[is / 4u], is % 4u); + let gp0_signs = get_byte(ksigns_iq2xs[gp0_is / 4u], gp0_is % 4u); + let gp1_signs = get_byte(ksigns_iq2xs[gp1_is / 4u], gp1_is % 4u); - let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 2u + l) * 2u) & 0xFFFFu; - let ig_byte = get_byte(ig_word, k2); - let g = get_byte(iq3xxs_grid[ig_byte], j); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); + let m_0_3_val4 = create_iq2_m4(gp0_signs, 0); + let m_4_7_val4 = create_iq2_m4(gp0_signs, 1); + let m_8_11_val4 = create_iq2_m4(gp1_signs, 0); + let m_12_15_val4 = create_iq2_m4(gp1_signs, 1); - shmem[elem_idx] = f16(db * f32(g) * m); - } -} + let gw_0_3_val4 = create_iq_gw4(ig_0_3); + let gw_4_7_val4 = create_iq_gw4(ig_4_7); + let gw_8_11_val4 = create_iq_gw4(ig_8_11); + let gw_12_15_val4 = create_iq_gw4(ig_12_15); + + store_shmem_iquants(vec4(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u); + store_shmem_iquants(vec4(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u); + store_shmem_iquants(vec4(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u); + store_shmem_iquants(vec4(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u); #endif // INIT_SRC0_SHMEM_IQ3_XXS -#ifdef INIT_SRC0_SHMEM_IQ3_S -const BLOCK_SIZE = 256u; -const BLOCK_SIZE_BYTES = 110u; +#if defined(INIT_SRC0_SHMEM_IQ3_S) + let block_byte_base = src0_idx * 110u; // BLOCK_SIZE_BYTES = 110u; + let d_byte_base = block_byte_base + 0u; + let qs_byte_base = block_byte_base + 2u; + let qh_byte_base = block_byte_base + 66u; + let signs_byte_base = block_byte_base + 74u; + let scales_byte_base = block_byte_base + 106u; -fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE) { - let tile_m = elem_idx / TILE_K; - let tile_k = elem_idx % TILE_K; - let global_m = offset_m + tile_m; - let global_k = k_outer + tile_k; + let d = load_f16_as_f32_at_src0(d_byte_base); - if (global_m >= params.m || global_k >= params.k) { - shmem[elem_idx] = f16(0.0); - continue; - } + let sub_block = k_in_block / 32u; + let phase = (k_in_block / NQ) % 2u; - let block_k = global_k / BLOCK_SIZE; - let k_in_block = global_k % BLOCK_SIZE; + let scale = (load_byte_at_src0_aligned(scales_byte_base + 1u * (sub_block / 2u)) >> (4u * (sub_block % 2u))) & 0xFu; + let db = d * (1.0 + 2.0 * f32(scale)); - let src0_idx = batch_offset + global_m * params.stride_01 + block_k; - let block_byte_base = src0_idx * BLOCK_SIZE_BYTES; - let d = load_f16_as_f32_at_src0(block_byte_base); + let qs_u32 = load_u32_at_src0(qs_byte_base + 8u * sub_block + 4u * phase); + let qh_u4 = (load_byte_at_src0_aligned(qh_byte_base + 1u * sub_block) >> (4u * phase)) & 0xFu; + let signs_u16 = (load_u32_at_src0(signs_byte_base + 4u * sub_block + 2u * phase)) & 0xFFFFu; - let ib = k_in_block / 64u; - let rest = k_in_block % 64u; - let k = rest / 32u; - let in_k = rest % 32u; - let l = in_k / 8u; - let in_l = in_k % 8u; - let k2 = in_l / 4u; - let j = in_l % 4u; + let ig_0_3 = ((qs_u32 >> 0u) & 0xFFu) | ((qh_u4 & 0x1u) << 8u); + let ig_4_7 = ((qs_u32 >> 8u) & 0xFFu) | ((qh_u4 & 0x2u) << 7u); + let ig_8_11 = ((qs_u32 >> 16u) & 0xFFu) | ((qh_u4 & 0x4u) << 6u); + let ig_12_15 = ((qs_u32 >> 24u) & 0xFFu) | ((qh_u4 & 0x8u) << 5u); - let scales_word = load_u32_at_src0(block_byte_base + 106u); - let s = get_byte(scales_word, ib); - let s_nib = select(s & 0xFu, (s >> 4u) & 0xFu, k != 0u); - let dl = d * (1.0 + 2.0 * f32(s_nib)); + let gp0_signs = get_byte(signs_u16, 0); + let gp1_signs = get_byte(signs_u16, 1); - let qh_word = load_u32_at_src0(block_byte_base + 66u + (ib / 2u) * 4u); - let qh_byte = get_byte(qh_word, (ib % 2u) * 2u + k); + let m_0_3_val4 = create_iq2_m4(gp0_signs, 0); + let m_4_7_val4 = create_iq2_m4(gp0_signs, 1); + let m_8_11_val4 = create_iq2_m4(gp1_signs, 0); + let m_12_15_val4 = create_iq2_m4(gp1_signs, 1); - let ig_word = load_u32_at_src0(block_byte_base + 2u + (ib * 8u + k * 4u + l) * 2u) & 0xFFFFu; - let ig_lo = get_byte(ig_word, 0u) | ((qh_byte << (8u - 2u * l)) & 256u); - let ig_hi = get_byte(ig_word, 1u) | ((qh_byte << (7u - 2u * l)) & 256u); - let ig = select(ig_lo, ig_hi, k2 != 0u); + let gw_0_3_val4 = create_iq_gw4(ig_0_3); + let gw_4_7_val4 = create_iq_gw4(ig_4_7); + let gw_8_11_val4 = create_iq_gw4(ig_8_11); + let gw_12_15_val4 = create_iq_gw4(ig_12_15); - let signs_word = load_u32_at_src0(block_byte_base + 74u + (ib * 2u + k) * 4u); - let signs = get_byte(signs_word, l); - - let g = get_byte(iq3s_grid[ig], j); - let m = select(1.0, -1.0, (get_byte(kmask_iq2xs[k2], j) & signs) != 0u); - - shmem[elem_idx] = f16(dl * f32(g) * m); + store_shmem_iquants(vec4(db * m_0_3_val4 * gw_0_3_val4), elem_idx + 0u); + store_shmem_iquants(vec4(db * m_4_7_val4 * gw_4_7_val4), elem_idx + 4u); + store_shmem_iquants(vec4(db * m_8_11_val4 * gw_8_11_val4), elem_idx + 8u); + store_shmem_iquants(vec4(db * m_12_15_val4 * gw_12_15_val4), elem_idx + 12u); +#endif // INIT_SRC0_SHMEM_IQ3_S } } -#endif // INIT_SRC0_SHMEM_IQ3_S +#endif // i-quants (super block size: 256)