diff --git a/examples/python/7.1_schedule.py b/examples/python/7.1_schedule.py index be229ee4da..2c2d6e0b0f 100755 --- a/examples/python/7.1_schedule.py +++ b/examples/python/7.1_schedule.py @@ -60,7 +60,13 @@ def _run_mxfp_gemm(gemm, shape): def _run_mxfp_gemm_preshuffle( - gemm, shape, all=False, only_scale=False, only_b=False, output_dtype=torch.float32 + gemm, + shape, + all=False, + only_scale=False, + only_b=False, + output_dtype=torch.float32, + transpose_output=False, ): """Run compiled GEMM kernel with preshuffled B and B_scale, verify against reference. @@ -68,30 +74,31 @@ def _run_mxfp_gemm_preshuffle( all - shuffle a_scale (x_scales), b_scale (w_scales), and b (w_t) only_scale - shuffle a_scale (x_scales) and b_scale (w_scales) only only_b - shuffle b_scale (w_scales) only + + When transpose_output is True, the kernel writes C^T [N, M] instead of C [M, N]. """ x, w, x_scales, w_scales = generate_gemm_afp4wfp4_inputs(shape) torch_out = torchScaledGemmMXFP4(x, w, x_scales, w_scales) w_t = w.T.contiguous() - # Apply b (w_t) preshuffle only when all=True w_t_ps = b_preshuffle(w_t) if all else w_t - # Apply a_scale shuffle when all=True or only_scale=True x_scales_ps = e8m0_shuffle(x_scales) if (all or only_scale) else x_scales - # Apply b_scale shuffle when all=True, only_scale=True, or only_b=True w_scales_ps = e8m0_shuffle(w_scales) if (all or only_scale or only_b) else w_scales x, w_t_ps = x.cuda(), w_t_ps.cuda() x_scales_ps, w_scales_ps = x_scales_ps.cuda(), w_scales_ps.cuda() - out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=output_dtype).cuda() + if transpose_output: + out = torch.zeros(w_t_ps.shape[0], x.shape[0], dtype=output_dtype).cuda() + else: + out = torch.zeros(x.shape[0], w_t_ps.shape[0], dtype=output_dtype).cuda() gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out) - torch.testing.assert_close( - torch_out, out.cpu(), check_dtype=False, check_device=False - ) + result = out.T.contiguous().cpu() if transpose_output else out.cpu() + torch.testing.assert_close(torch_out, result, check_dtype=False, check_device=False) def _get_8wave_shape_from_block(block): @@ -200,7 +207,7 @@ def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle( def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle_lds( - is_debug=False, shape=(1024, 1024, 8192), block=(256, 256, 256), dynamic=False + is_debug=False, shape=(1792, 5376, 4096), block=(256, 192, 256), dynamic=True ): """Double-buffered MXFP4 GEMM, 8 waves, ping-pong with stagger. A&B scales are preshuffled and read from global memory directly to VGPRs. 
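For reference, the transpose_output verification path added above reduces to the following minimal sketch in plain PyTorch, with copy_ standing in for the compiled kernel and arbitrary small shapes; it illustrates the check only and is not part of the patched test:

import torch

M, N, K = 128, 64, 32
x = torch.randn(M, K)
w_t = torch.randn(N, K)   # B stored as [N, K], matching w.T.contiguous()
ref = x @ w_t.T           # reference C with shape [M, N]

# transpose_output=True: the kernel is handed an [N, M] buffer and writes C^T.
out = torch.zeros(N, M)
out.copy_(ref.T)          # stand-in for gemm(x, x_scales_ps, w_t_ps, w_scales_ps, out)

# Verification mirrors _run_mxfp_gemm_preshuffle: transpose back, then compare.
torch.testing.assert_close(ref, out.T.contiguous())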
@@ -213,25 +220,49 @@ def test_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle_lds(
         block,
         wave_shape=wave_shape,
         b_address_space=SHARED_ADDRESS_SPACE,
+        output_dtype=tkl.bf16,
+        transpose_output=True,
     )
     options.specialize = True
     options.use_buffer_ops = True
-    options.minimize_shared_allocs = False
+    options.minimize_shared_allocs = True
     options.linearize_shared_access = True
+    options.wave_runtime = True
+    options.coalesce_epilogue_stores = True
+    options.dump_intermediates = "build/intermediates/coalesce_epi"
 
     if dynamic:
         options.dynamic_symbols = [tkl.sym.M, tkl.sym.N, tkl.sym.K]
         for sym in options.dynamic_symbols:
             del options.subs[sym]
 
     schedule = get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds(
-        use_stagger=True, shape=shape
+        use_stagger=True, shape=shape, block=block
     )
+    UNROLL_FACTOR = tkl.sym.UNROLL_FACTOR
+    options.subs[UNROLL_FACTOR] = 2
+    options.postprocess = """
+    module attributes {transform.with_named_sequence} {
+      transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
+        %0 = transform.structured.match ops{["scf.for"]} in %arg0 : (!transform.any_op) -> !transform.any_op
+        transform.loop.unroll %0 { factor = %%UNROLL_FACTOR%% } : !transform.any_op
+        transform.yield
+      }
+    }
+    """
     options.print_ir_after = "all" if is_debug else []
     options = set_default_run_config(options)
     gemm = wave_compile(options, gemm, schedule)
-    _run_mxfp_gemm_preshuffle(gemm, shape, all=True)
+    with open(
+        "build/intermediates/coalesce_epi/gemm_mxfp4_dbuf_8wave_pingpong_mxfp_gemm_Bshuffle_lds.mlir",
+        "w",
+    ) as f:
+        f.write(gemm.asm)
+
+    _run_mxfp_gemm_preshuffle(
+        gemm, shape, all=True, output_dtype=torch.bfloat16, transpose_output=True
+    )
 
     mode = "dynamic" if dynamic else "static"
     print(
         f"MXFP GEMM double-buffer 8-wave ping pong with scales and B shuffling and B->LDS ({mode}) test passed!"
diff --git a/wave_lang/kernel/compiler/wave_codegen/emitter.py b/wave_lang/kernel/compiler/wave_codegen/emitter.py
index c97d311582..c99cc317c4 100644
--- a/wave_lang/kernel/compiler/wave_codegen/emitter.py
+++ b/wave_lang/kernel/compiler/wave_codegen/emitter.py
@@ -633,10 +633,13 @@ def add_emitter_subs(
 
 _emulate_ceildiv = bool(int(environ.get("WAVE_EMULATE_CEILDIV", 0)))
 _use_affine_expr = bool(int(environ.get("WAVE_USE_AFFINE_EXPR", 1)))
+_magic_number_enabled = bool(int(environ.get("WAVE_MAGIC_NUMBER_DIV", 1)))
 
 _Rational = namedtuple("_Rational", ["numerator", "denominator"])
 _ApplyExpr = namedtuple("_ApplyExpr", ["expr", "args"])
 
+_magic_number_cache: dict = {}
+
 
 def gen_sympy_index(dynamics: dict[IndexSymbol, Value], expr: sympy.Expr) -> Value:
     use_affine_expr = _use_affine_expr
@@ -714,6 +717,26 @@ def get_const_val(arg):
 
         return None
 
+    def _same_value(a, b) -> bool:
+        """Structural equality for emitter values (_ApplyExpr / OpResult).
+
+        Detects when two independently-built IR values represent the same
+        computation. Used in _add to avoid denominator explosion when
+        adding _Rationals that share a common denominator.
+ """ + if isinstance(a, _ApplyExpr) and isinstance(b, _ApplyExpr): + if a.expr != b.expr or len(a.args) != len(b.args): + return False + return all(_same_value(x, y) for x, y in zip(a.args, b.args)) + if isinstance(a, (Value, OpResult)) and isinstance(b, (Value, OpResult)): + if a is b: + return True + a_val = get_const_val(a) + b_val = get_const_val(b) + if a_val is not None and b_val is not None: + return a_val == b_val + return False + overflow_flags = arith_d.IntegerOverflowFlags.nsw | arith_d.IntegerOverflowFlags.nuw def muli(lhs, rhs): @@ -778,16 +801,105 @@ def muli_expr(lhs, rhs): return op_expr(lhs, rhs, lambda a, b: a * b) + def _is_dynamic_divisor(val) -> bool: + """Check if a value is NOT a compile-time constant.""" + if isinstance(val, _ApplyExpr): + if all( + isinstance(a, OpResult) and get_const_val(a) is not None + for a in val.args + ): + return False + val = _get_ir_value(val) + if isinstance(val, OpResult): + return get_const_val(val) is None + return True + + def _mulhi_u32(n_i32, m_i32): + """Unsigned 32-bit multiply-high: (n * m) >> 32, via 64-bit multiply.""" + i64 = IntegerType.get_signless(64) + c32_i64 = arith_d.constant(i64, 32) + n_i64 = arith_d.extui(i64, n_i32) + m_i64 = arith_d.extui(i64, m_i32) + prod_i64 = arith_d.muli(n_i64, m_i64) + hi_i64 = arith_d.shrui(prod_i64, c32_i64) + i32 = IntegerType.get_signless(32) + return arith_d.trunci(i32, hi_i64) + + def _precompute_magic_number(divisor_index: Value): + """ + Compute magic = ceil(2^32 / d) from a dynamic divisor. + Returns (magic_i32, d_i32) both as i32 Values. + """ + i32 = IntegerType.get_signless(32) + i64 = IntegerType.get_signless(64) + d_i32 = arith_d.index_cast(i32, divisor_index) + d_i64 = arith_d.extui(i64, d_i32) + c1_i64 = arith_d.constant(i64, 1) + c32_i64 = arith_d.constant(i64, 32) + pow32 = arith_d.shli(c1_i64, c32_i64) + d_minus_1_i64 = arith_d.subi(d_i64, c1_i64) + numer_i64 = arith_d.addi(pow32, d_minus_1_i64) + magic_i64 = arith_d.divui(numer_i64, d_i64) + magic_i32 = arith_d.trunci(i32, magic_i64) + return magic_i32, d_i32 + + def _get_or_create_magic(divisor: Value): + """Get cached (magic_i32, d_i32) or compute and cache them.""" + key = id(divisor) + if key in _magic_number_cache: + return _magic_number_cache[key] + magic_i32, d_i32 = _precompute_magic_number(divisor) + _magic_number_cache[key] = (magic_i32, d_i32) + return magic_i32, d_i32 + + def _magic_div_and_rem(lhs_val, rhs_val): + """ + Compute (quotient, remainder) of lhs_val // rhs_val using + magic number multiplication: q = mulhi(n, magic), with a + one-step correction for exactness. + Returns (quotient_index, remainder_index). + """ + i32 = IntegerType.get_signless(32) + magic_i32, d_i32 = _get_or_create_magic(rhs_val) + n_i32 = arith_d.index_cast(i32, lhs_val) + q_i32 = _mulhi_u32(n_i32, magic_i32) + qd_i32 = arith_d.muli(q_i32, d_i32) + r_i32 = arith_d.subi(n_i32, qd_i32) + # Correction: ceil(2^32/d) can overestimate quotient by 1. + # Detect via unsigned remainder >= divisor (wraps on overestimate). 
+ too_big = arith_d.cmpi(arith_d.CmpIPredicate.uge, r_i32, d_i32) + c1_i32 = arith_d.constant(i32, 1) + c0_i32 = arith_d.constant(i32, 0) + corr = arith_d.select(too_big, c1_i32, c0_i32) + q_final = arith_d.subi(q_i32, corr) + d_or_zero = arith_d.select(too_big, d_i32, c0_i32) + r_final = arith_d.addi(r_i32, d_or_zero) + q_index = arith_d.index_cast(IndexType.get(), q_final) + r_index = arith_d.index_cast(IndexType.get(), r_final) + return q_index, r_index + def rem_expr(lhs, rhs): if not use_affine_expr or not check_index_types(lhs, rhs): return arith_d.remsi(*_broadcast(lhs, rhs)) + if _magic_number_enabled and _is_dynamic_divisor(rhs): + lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs + rhs_val = _get_ir_value(rhs) if isinstance(rhs, _ApplyExpr) else rhs + _, r = _magic_div_and_rem(lhs_val, rhs_val) + return r + return op_expr(lhs, rhs, lambda a, b: a % b) def floordiv_expr(lhs, rhs): if not use_affine_expr or not check_index_types(lhs, rhs): return arith_d.divsi(*_broadcast(lhs, rhs)) + if _magic_number_enabled and _is_dynamic_divisor(rhs): + lhs_val = _get_ir_value(lhs) if isinstance(lhs, _ApplyExpr) else lhs + rhs_val = _get_ir_value(rhs) if isinstance(rhs, _ApplyExpr) else rhs + q, _ = _magic_div_and_rem(lhs_val, rhs_val) + return q + return op_expr(lhs, rhs, lambda a, b: AffineExpr.get_floor_div(a, b)) def ceildiv_expr(lhs, rhs): @@ -820,6 +932,9 @@ def _add(lhs, rhs): numerator = add_expr(numerator, rhs.numerator) return _Rational(numerator, rhs.denominator) elif is_rational_lhs and is_rational_rhs: + if _same_value(lhs.denominator, rhs.denominator): + numerator = add_expr(lhs.numerator, rhs.numerator) + return _Rational(numerator, lhs.denominator) lhs_numerator = muli_expr(lhs.numerator, rhs.denominator) rhs_numerator = muli_expr(rhs.numerator, lhs.denominator) numerator = add_expr(lhs_numerator, rhs_numerator) diff --git a/wave_lang/kernel/compiler/wave_codegen/read_write.py b/wave_lang/kernel/compiler/wave_codegen/read_write.py index 76386a8793..b7a28d1fe1 100644 --- a/wave_lang/kernel/compiler/wave_codegen/read_write.py +++ b/wave_lang/kernel/compiler/wave_codegen/read_write.py @@ -14,6 +14,7 @@ from wave_lang.kernel.wave.utils.graph_utils import propagate_loop_carried_vars from wave_lang.support.ir_imports import ( Attribute, + BF16Type, DenseElementsAttr, IndexType, InsertionPoint, @@ -31,6 +32,7 @@ gpu_d, llvm_d, memref_d, + rocdl_d, vector_d, ) from .ir_utils import ( @@ -1143,6 +1145,24 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): ) use_llvm_store = flags != MemoryAccessFlags.NONE + + is_shared = get_custom(memory).type.address_space == SHARED_ADDRESS_SPACE + is_bf16 = isinstance(element_type, BF16Type) + + if not is_shared and is_bf16 and getattr(node, "_permlane_pack_global", False): + _write_permlane_pack_to_global( + emitter, + insert_vector, + kb_dest, + output_shape, + start_indices, + start_indices_wg, + start_indices_th, + get_custom(memory), + index, + ) + return + if use_llvm_store: _create_llvm_read_write( kb_dest, kb_ir_type, start_indices, insert_type, flags, insert_vector @@ -1164,6 +1184,115 @@ def handle_write(emitter: WaveEmitter, node: fx.Node): ) +def _write_permlane_pack_to_global( + emitter: WaveEmitter, + insert_vector: Value, + kb_dest: Value, + output_shape: tuple, + start_indices: tuple, + start_indices_wg: tuple, + start_indices_th: tuple, + memory_custom, + index: dict, +): + """Pack two lanes' bf16 values via permlane16_swap for wide global stores. 
+ + MMA accumulator layout (F32_16x16x128_F8F6F4) gives each thread 4 + consecutive M values. Lanes are grouped by 16: lanes 0-15 own M=0-3, + lanes 16-31 own M=4-7, etc. ``v_permlane16_swap_b32`` exchanges data + between paired groups, giving each lane 8 consecutive M values that + can be written as a single ``buffer_store_dwordx4`` (128 bits). + + Both lane halves produce identical data at the same address (benign + duplicate store): + + - Lower half (lanes 0-15 in each 32-lane group): + data = [own, partner], address = thread's original M index. + - Upper half (lanes 16-31): + data = [partner, own], address = original M index - 4. + + This dual-write avoids divergent control flow (no scf.if / exec + masking needed). The buffer descriptor's ``valid_bytes`` handles + out-of-bounds suppression for dynamic shapes. + + Precondition: M must be the innermost (last) memory dimension with + stride 1 (i.e. transpose_output=True, shape [N, M]). + """ + bf16_type = BF16Type.get() + i32_type = IntegerType.get_signless(32) + idx_type = IndexType.get() + v2i32_type = VectorType.get([2], i32_type) + v4i32_type = VectorType.get([4], i32_type) + v8bf16_type = VectorType.get([8], bf16_type) + + # Bitcast 4 x bf16 -> 2 x i32 so permlane16_swap can operate on dwords. + i32_vec = vector_d.bitcast(v2i32_type, insert_vector) + own_lo = vector_d.extract(i32_vec, static_position=[0], dynamic_position=[]) + own_hi = vector_d.extract(i32_vec, static_position=[1], dynamic_position=[]) + + # Exchange dwords with the partner lane 16 positions apart. + swap_type = llvm_d.StructType.get_literal([i32_type, i32_type]) + partner_lo = llvm_d.extractvalue( + i32_type, rocdl_d.permlane16_swap(swap_type, own_lo, own_lo, False, False), [0] + ) + partner_hi = llvm_d.extractvalue( + i32_type, rocdl_d.permlane16_swap(swap_type, own_hi, own_hi, False, False), [0] + ) + + # Classify this lane as lower (0-15) or upper (16-31) within each + # 32-lane half-wave. + lane_in_wave = arith_d.remui(emitter.thread_ids[0], arith_d.constant(idx_type, 64)) + half_pos = arith_d.remui(lane_in_wave, arith_d.constant(idx_type, 32)) + is_lower = arith_d.cmpi( + arith_d.CmpIPredicate.ult, half_pos, arith_d.constant(idx_type, 16) + ) + + # Both halves build identical 8-bf16 vectors, but in complementary + # order so they land at the same memory address: + # lower: [own_lo, own_hi, partner_lo, partner_hi] @ M + # upper: [partner_lo, partner_hi, own_lo, own_hi] @ M - 4 + d0 = arith_d.select(is_lower, own_lo, partner_lo) + d1 = arith_d.select(is_lower, own_hi, partner_hi) + d2 = arith_d.select(is_lower, partner_lo, own_lo) + d3 = arith_d.select(is_lower, partner_hi, own_hi) + + wide_i32 = vector_d.from_elements(v4i32_type, [d0, d1, d2, d3]) + wide_vec = vector_d.bitcast(v8bf16_type, wide_i32) + + # Adjust the M (last) dimension index for the upper half so both + # halves target the same 8-element span starting at the lower half's + # M base. M is contiguous (stride 1), so subtracting 4 elements + # from the index subtracts 4 from the linearized element offset. + four = arith_d.constant(idx_type, 4) + + adj_th = list(start_indices_th) + adj_th[-1] = arith_d.select(is_lower, adj_th[-1], arith_d.subi(adj_th[-1], four)) + + adj_full = list(start_indices) + adj_full[-1] = arith_d.select( + is_lower, adj_full[-1], arith_d.subi(adj_full[-1], four) + ) + + # mask=None: the buffer descriptor encodes valid_bytes for the full + # output buffer, so OOB stores at tile edges are silently dropped by + # hardware. 
The original Write node's mask was sized for 4 elements, + # incompatible with our 8-element vector. + _create_vec_read_write( + emitter, + output_shape, + kb_dest, + wide_vec, + None, + tuple(adj_full), + start_indices_wg, + tuple(adj_th), + 8, + memory_custom, + None, + node_index=index, + ) + + def assume_index_subgroup_uniform(value: Value, element_type: IrType) -> Value: res = gpu_d.subgroup_broadcast(value, gpu_d.BroadcastType.first_active_lane) return res diff --git a/wave_lang/kernel/wave/coalesce_epilogue_stores.py b/wave_lang/kernel/wave/coalesce_epilogue_stores.py new file mode 100644 index 0000000000..e524b0d638 --- /dev/null +++ b/wave_lang/kernel/wave/coalesce_epilogue_stores.py @@ -0,0 +1,48 @@ +# Licensed under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +""" +Graph pass that coalesces epilogue bf16 stores via permlane16_swap. + +Marks eligible Write nodes so the codegen combines each thread's 4 bf16 +values with its partner lane's (16 lanes apart) via v_permlane16_swap_b32, +producing 8 consecutive bf16 written as a single buffer_store_dwordx4. +No LDS staging or barriers required. + +Precondition: the output memory must have M as the innermost (contiguous) +dimension (i.e. transpose_output=True producing [N, M] layout) so that 8 +consecutive bf16 elements span 8 adjacent M rows. +""" + +from .._support.tracing import CapturedTrace +from ..lang.global_symbols import GLOBAL_ADDRESS_SPACE +from ..ops.wave_ops import Write, get_custom +from .region_canonicalization import RegionFormat, requires_region_format +from .utils.symbol_utils import subs_idxc + + +@requires_region_format(RegionFormat.SCHEDULE_SIGNATURE_PLACEHOLDERS) +def coalesce_epilogue_stores(trace: CapturedTrace): + """Tag epilogue bf16 global writes for permlane16_swap packing. + + Walks the root graph and sets ``_permlane_pack_global = True`` on + every Write node that targets global memory with bf16 dtype. + The codegen in ``_write_permlane_pack_to_global`` handles the rest. 
+ """ + import wave_lang.kernel.lang as tkl + + root_graph = trace.get_root_graph() + + for node in root_graph.nodes: + if node.op != "call_function": + continue + custom = get_custom(node) + if not isinstance(custom, Write): + continue + mem_type = custom.memory_type + if ( + subs_idxc(mem_type.address_space) == GLOBAL_ADDRESS_SPACE + and mem_type.dtype == tkl.bf16 + ): + node._permlane_pack_global = True diff --git a/wave_lang/kernel/wave/compile.py b/wave_lang/kernel/wave/compile.py index c71ddc9e74..e23add96c2 100644 --- a/wave_lang/kernel/wave/compile.py +++ b/wave_lang/kernel/wave/compile.py @@ -529,6 +529,12 @@ def build_graph_passes( partial(partition_ops_with_gpr_offsets, trace, launchable.constraints), partial(partition_strided_operators, trace, launchable.constraints), partial(remove_chained_extractslice, trace), + partial( + merge_contiguous_reads, + trace, + launchable.constraints, + options.target, + ), ] graph_passes += [ @@ -583,6 +589,11 @@ def build_graph_passes( options.minimize_shared_allocs, ), ] + if options.coalesce_epilogue_stores: + from .coalesce_epilogue_stores import coalesce_epilogue_stores + + graph_passes.append(partial(coalesce_epilogue_stores, trace)) + graph_passes += [ partial( add_shared_memory_barriers, @@ -596,6 +607,9 @@ def build_graph_passes( partial( partition_gather_like_ops, trace, launchable.constraints, options.target ), + ] + + graph_passes += [ partial( generate_bounds_exprs, trace, diff --git a/wave_lang/kernel/wave/compile_options.py b/wave_lang/kernel/wave/compile_options.py index fafb01453b..4019e14494 100644 --- a/wave_lang/kernel/wave/compile_options.py +++ b/wave_lang/kernel/wave/compile_options.py @@ -104,6 +104,8 @@ class WaveCompileOptions: specialize: bool = False eliminate_epilogue: bool = False + coalesce_epilogue_stores: bool = False + # Cluster barrier signal/wait delay in number of loop iterations # None - no barriers inside the loop # 0 - signal and wait on same iteration diff --git a/wave_lang/kernel/wave/gather_to_shared.py b/wave_lang/kernel/wave/gather_to_shared.py index 23e3345ca4..ee1c633d54 100644 --- a/wave_lang/kernel/wave/gather_to_shared.py +++ b/wave_lang/kernel/wave/gather_to_shared.py @@ -27,8 +27,10 @@ Write, get_custom, ) +from ..wave.assumptions import get_divisibility_subs from ..wave.constraints import ( Constraint, + DistributionConstraint, TilingConstraint, WaveConstraint, WorkgroupConstraint, @@ -49,6 +51,7 @@ remove_thread_indexing, remove_global_indexing, ) +from .generate_bounds_exprs import is_divisible from .utils.general_utils import is_gather from .utils.graph_utils import DCE from .utils.symbol_utils import subs_idxc @@ -542,6 +545,18 @@ def gather_to_shared( constraints, read.index, vector_shapes, symbolic_shape ) + # Remove bounds where divisibility assumptions prove full tile alignment + fwd, _ = get_divisibility_subs(constraints) + if bounds and fwd: + for c in constraints: + if ( + isinstance(c, DistributionConstraint) + and c.dim in bounds + and is_divisible(c.dim, c.tile_size, fwd) + ): + del bounds[c.dim] + bounds = bounds or None + logger.info(f"bounds={bounds}") fastest_dim_bound = bounds.get(symbolic_shape[-1], None) if bounds else None diff --git a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py index 6069941284..f50bb419dd 100755 --- a/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py +++ b/wave_lang/kernel/wave/schedules/gemm_mxfp4_double_buffer.py @@ -44,8 +44,9 @@ def mxfp4_dbuf_schedule(): 
global_to_shared_a = tkw.filter_nodes(all_read_a, node_type=tkw.GatherToLDS) shared_load_a = tkw.filter_nodes(all_read_a, node_type=tkw.Read) - # Matrix A scale + # Matrix A scale (global -> VGPR reads) all_read_a_scale = tkw.get_node_by_tag("read_a_scale") + read_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.Read) global_to_shared_a_scale = tkw.filter_nodes( all_read_a_scale, node_type=tkw.GatherToLDS ) @@ -56,8 +57,9 @@ def mxfp4_dbuf_schedule(): global_to_shared_b = tkw.filter_nodes(all_read_b, node_type=tkw.GatherToLDS) shared_load_b = tkw.filter_nodes(all_read_b, node_type=tkw.Read) - # Matrix B scale + # Matrix B scale (global -> VGPR reads) all_read_b_scale = tkw.get_node_by_tag("read_b_scale") + read_b_scale = tkw.filter_nodes(all_read_b_scale, node_type=tkw.Read) global_to_shared_b_scale = tkw.filter_nodes( all_read_b_scale, node_type=tkw.GatherToLDS ) @@ -276,16 +278,20 @@ def mxfp4_dbuf_schedule(): global_to_shared_a = tkw.filter_nodes(all_read_a, node_type=tkw.GatherToLDS) shared_load_a = tkw.filter_nodes(all_read_a, node_type=tkw.Read) - # Matrix A scale + # Matrix A scale (global -> VGPR reads) all_read_a_scale = tkw.get_node_by_tag("read_a_scale") + read_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.Read) + read_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.Read) # Matrix B data all_read_b = tkw.get_node_by_tag("read_b") global_to_shared_b = tkw.filter_nodes(all_read_b, node_type=tkw.GatherToLDS) shared_load_b = tkw.filter_nodes(all_read_b, node_type=tkw.Read) - # Matrix B scale + # Matrix B scale (global -> VGPR reads) all_read_b_scale = tkw.get_node_by_tag("read_b_scale") + read_b_scale = tkw.filter_nodes(all_read_b_scale, node_type=tkw.Read) + read_b_scale = tkw.filter_nodes(all_read_b_scale, node_type=tkw.Read) # Bitcast operations (needed alongside compute) bitcast_a = tkw.get_node_by_tag("bitcast_a") @@ -308,8 +314,8 @@ def mxfp4_dbuf_schedule(): ( global_to_shared_a, global_to_shared_b, - all_read_a_scale, - all_read_b_scale, + read_a_scale, + read_b_scale, ), (), (), @@ -342,11 +348,11 @@ def mxfp4_dbuf_schedule(): loop_shared_load_b = tkw.filter_nodes( shared_load_b, subgraph=pipeline_loop.KERNEL ) - loop_all_read_a_scale = tkw.filter_nodes( - all_read_a_scale, subgraph=pipeline_loop.KERNEL + loop_read_a_scale = tkw.filter_nodes( + read_a_scale, subgraph=pipeline_loop.KERNEL ) - loop_all_read_b_scale = tkw.filter_nodes( - all_read_b_scale, subgraph=pipeline_loop.KERNEL + loop_read_b_scale = tkw.filter_nodes( + read_b_scale, subgraph=pipeline_loop.KERNEL ) loop_bitcast_a = tkw.filter_nodes(bitcast_a, subgraph=pipeline_loop.KERNEL) loop_bitcast_a_scale = tkw.filter_nodes( @@ -367,11 +373,11 @@ def mxfp4_dbuf_schedule(): loop_shared_load_b_0, loop_shared_load_b_1 = tkw.partition_by_dim( loop_shared_load_b, dim=K, num_partitions=2 ) - loop_all_read_a_scale_0, loop_all_read_a_scale_1 = tkw.partition_by_dim( - loop_all_read_a_scale, dim=K, num_partitions=2 + loop_read_a_scale_0, loop_read_a_scale_1 = tkw.partition_by_dim( + loop_read_a_scale, dim=K, num_partitions=2 ) - loop_all_read_b_scale_0, loop_all_read_b_scale_1 = tkw.partition_by_dim( - loop_all_read_b_scale, dim=K, num_partitions=2 + loop_read_b_scale_0, loop_read_b_scale_1 = tkw.partition_by_dim( + loop_read_b_scale, dim=K, num_partitions=2 ) loop_bitcast_a_0, loop_bitcast_a_1 = tkw.partition_by_dim( @@ -409,8 +415,8 @@ def mxfp4_dbuf_schedule(): loop_bitcast_a_scale_0, loop_bitcast_b_0, loop_bitcast_b_scale_0, - loop_all_read_a_scale_0, # prefetch A & B scales for next 
iteration - loop_all_read_b_scale_0, + loop_read_a_scale_0, # prefetch A & B scales for next iteration + loop_read_b_scale_0, tkw.SchedulingBarrier([]), ] ) @@ -446,8 +452,8 @@ def mxfp4_dbuf_schedule(): loop_bitcast_a_scale_1, loop_bitcast_b_1, loop_bitcast_b_scale_1, - loop_all_read_a_scale_1, - loop_all_read_b_scale_1, + loop_read_a_scale_1, + loop_read_b_scale_1, tkw.SchedulingBarrier([]), tkw.WorkgroupBarrier(), tkw.SchedulingBarrier([]), @@ -507,14 +513,16 @@ def mxfp4_dbuf_schedule(): global_to_shared_a = tkw.filter_nodes(all_read_a, node_type=tkw.GatherToLDS) shared_load_a = tkw.filter_nodes(all_read_a, node_type=tkw.Read) - # Matrix A scale + # Matrix A scale (global -> VGPR reads) all_read_a_scale = tkw.get_node_by_tag("read_a_scale") + read_a_scale = tkw.filter_nodes(all_read_a_scale, node_type=tkw.Read) # Matrix B data all_read_b = tkw.get_node_by_tag("read_b") - # Matrix B scale + # Matrix B scale (global -> VGPR reads) all_read_b_scale = tkw.get_node_by_tag("read_b_scale") + read_b_scale = tkw.filter_nodes(all_read_b_scale, node_type=tkw.Read) # Bitcast operations (needed alongside compute) bitcast_a = tkw.get_node_by_tag("bitcast_a") @@ -683,7 +691,7 @@ def mxfp4_dbuf_schedule(): def get_mxfp4_dbuf_pingpong_schedule_Bshuffled_lds( - use_stagger: bool = True, shape: tuple = None + use_stagger: bool = True, shape: tuple = None, block: tuple = None ): """Return a double-buffered MXFP4 schedule for wave_compile(). Same as get_mxfp4_dbuf_pingpong_schedule_Bshuffled(), but B data is read @@ -710,7 +718,7 @@ def mxfp4_dbuf_schedule(): global_to_shared_a = tkw.filter_nodes(all_read_a, node_type=tkw.GatherToLDS) shared_load_a = tkw.filter_nodes(all_read_a, node_type=tkw.Read) - # Matrix A scale + # Matrix A scale (global -> VGPR reads) all_read_a_scale = tkw.get_node_by_tag("read_a_scale") # Matrix B data - GatherToLDS (global->shared) + Read (shared load) @@ -718,8 +726,9 @@ def mxfp4_dbuf_schedule(): global_to_shared_b = tkw.filter_nodes(all_read_b, node_type=tkw.GatherToLDS) shared_load_b = tkw.filter_nodes(all_read_b, node_type=tkw.Read) - # Matrix B scale + # Matrix B scale (global -> VGPR reads) all_read_b_scale = tkw.get_node_by_tag("read_b_scale") + read_b_scale = tkw.filter_nodes(all_read_b_scale, node_type=tkw.Read) # Bitcast operations (needed alongside compute) bitcast_a = tkw.get_node_by_tag("bitcast_a") @@ -807,79 +816,236 @@ def mxfp4_dbuf_schedule(): loop_bitcast_b, dim=K, num_partitions=2 ) + # Count only actual global->VGPR Read nodes from the scheduled scale groups. + loop_a_scale_reads = tkw.filter_nodes(loop_all_read_a_scale, node_type=tkw.Read) + loop_b_scale_reads = tkw.filter_nodes(loop_all_read_b_scale, node_type=tkw.Read) + number_outstanding_loads_to_vgpr = len(loop_a_scale_reads) + len( + loop_b_scale_reads + ) + safe = 0 + if block is not None and block == (256, 192, 256): + safe = 2 + + if block is not None and block == (256, 160, 256): + safe = 5 + # If the bus gets congested and cluster memory dependency are affected, we must add a second barrier to fix the timing and prevent incorrect output results. # In case a second a second workgroup barrier is needed, another schedule is created to hide the latency of that second barrier, by scheduling safe ds_read ops before the second barrier (see get_mxfp4_dbuf_mixed_pingpong_schedule). 
- use_extra_barrier = True - # Build cluster 0: first K-partition loads + bitcasts + GatherToLDS - cluster_0_ops = [ - tkw.SchedulingBarrier([]), - tkw.MemoryCounterWait(load=0), - tkw.WorkgroupBarrier(), - ] - if use_extra_barrier: - cluster_0_ops.append(tkw.WorkgroupBarrier()) - cluster_0_ops.extend( - [ - loop_global_to_shared, - tkw.SchedulingBarrier([]), - loop_shared_load_a_0, - loop_shared_load_b_0, - loop_bitcast_a_0, - loop_bitcast_a_scale, - loop_bitcast_b_0, - loop_bitcast_b_scale, - loop_all_read_a_scale, # prefetch A & B scales for next iteration - loop_all_read_b_scale, + use_extra_barrier = False + + if block is not None and block == (256, 192, 256): + + cluster_0_ops = [ tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=number_outstanding_loads_to_vgpr), + tkw.WorkgroupBarrier(), ] - ) - if use_stagger: + if use_extra_barrier: + cluster_0_ops.append(tkw.WorkgroupBarrier()) cluster_0_ops.extend( [ - tkw.WorkgroupBarrier(), + loop_global_to_shared, + tkw.SchedulingBarrier([]), + loop_shared_load_a_0, + loop_shared_load_b_0, + loop_bitcast_a_0, + loop_bitcast_a_scale, + loop_bitcast_b_0, + loop_bitcast_b_scale, tkw.SchedulingBarrier([]), ] ) + if use_stagger: + cluster_0_ops.extend( + [ + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ] + ) - clusters = [ - # Cluster 0: First K-partition shared loads/bitcasts + async GatherToLDS - tkw.cluster(cluster_0_ops), - # Cluster 1: First K-partition scaled_mma (high priority) - tkw.cluster( - [ - tkw.SetWavePrio(1), - loop_scaled_mma_0, - tkw.SetWavePrio(0), - tkw.SchedulingBarrier([]), - tkw.WorkgroupBarrier(), - tkw.SchedulingBarrier([]), - ], - ), - # Cluster 2: Second K-partition shared loads/bitcasts - tkw.cluster( + clusters = [ + tkw.cluster(cluster_0_ops), + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_0, + tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + tkw.SchedulingBarrier([]), + loop_all_read_a_scale, + loop_all_read_b_scale, + loop_shared_load_a_1, + loop_shared_load_b_1, + loop_bitcast_a_1, + loop_bitcast_b_1, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait( + load=(number_outstanding_loads_to_vgpr + safe) + ), + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_1, + tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + ], + ), + ] + tkw.insert_before( + pipeline_loop.KERNEL, + tkw.MemoryCounterWait(load=number_outstanding_loads_to_vgpr), + ) + + elif block is not None and block == (256, 160, 256): + cluster_0_ops = [ + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=number_outstanding_loads_to_vgpr), + tkw.WorkgroupBarrier(), + ] + if use_extra_barrier: + cluster_0_ops.append(tkw.WorkgroupBarrier()) + cluster_0_ops.extend( [ + loop_global_to_shared, tkw.SchedulingBarrier([]), + loop_shared_load_a_0, + loop_shared_load_b_0, + loop_bitcast_a_0, + loop_bitcast_a_scale, + loop_bitcast_b_0, + loop_bitcast_b_scale, loop_shared_load_a_1, loop_shared_load_b_1, - loop_bitcast_a_1, - loop_bitcast_b_1, tkw.SchedulingBarrier([]), - tkw.WorkgroupBarrier(), - tkw.SchedulingBarrier([]), - ], - ), - # Cluster 3: Second K-partition scaled_mma (high priority) - tkw.cluster( + ] + ) + if use_stagger: + cluster_0_ops.extend( + [ + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ] + ) + + clusters = [ + tkw.cluster(cluster_0_ops), + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_0, + tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + 
tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + tkw.SchedulingBarrier([]), + loop_bitcast_a_1, + loop_bitcast_b_1, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait( + load=(number_outstanding_loads_to_vgpr + safe) + ), + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_1, + tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + ], + ), + ] + tkw.insert_before( + pipeline_loop.KERNEL, + tkw.MemoryCounterWait(load=number_outstanding_loads_to_vgpr), + ) + + else: + + cluster_0_ops = [ + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=0), + tkw.WorkgroupBarrier(), + ] + if use_extra_barrier: + cluster_0_ops.append(tkw.WorkgroupBarrier()) + cluster_0_ops.extend( [ - tkw.SetWavePrio(1), - loop_scaled_mma_1, - tkw.SetWavePrio(0), + loop_global_to_shared, tkw.SchedulingBarrier([]), - ], - ), - ] + loop_shared_load_a_0, + loop_shared_load_b_0, + loop_bitcast_a_0, + loop_bitcast_a_scale, + loop_bitcast_b_0, + loop_bitcast_b_scale, + loop_all_read_a_scale, + loop_all_read_b_scale, + tkw.SchedulingBarrier([]), + ] + ) + if use_stagger: + cluster_0_ops.extend( + [ + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ] + ) + + clusters = [ + tkw.cluster(cluster_0_ops), + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_0, + tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + tkw.SchedulingBarrier([]), + loop_shared_load_a_1, + loop_shared_load_b_1, + loop_bitcast_a_1, + loop_bitcast_b_1, + tkw.SchedulingBarrier([]), + tkw.MemoryCounterWait(load=(number_outstanding_loads_to_vgpr)), + tkw.WorkgroupBarrier(), + tkw.SchedulingBarrier([]), + ], + ), + tkw.cluster( + [ + tkw.SetWavePrio(1), + loop_scaled_mma_1, + tkw.SetWavePrio(0), + tkw.SchedulingBarrier([]), + ], + ), + ] + tkw.insert_before(pipeline_loop.KERNEL, tkw.MemoryCounterWait(load=0)) # Insert barriers at loop boundaries + tkw.insert_before(pipeline_loop.KERNEL, tkw.WorkgroupBarrier()) tkw.insert_after(pipeline_loop.KERNEL, tkw.SharedMemoryBarrier()) diff --git a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py index b7b013460b..42a92af385 100755 --- a/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py +++ b/wave_lang/kernel/wave/templates/tagged_mxfp4_gemm.py @@ -149,6 +149,8 @@ def _get_tagged_mxfp4_gemm_preshuffle_scales_impl( b_preshuffled: bool = False, reorder_workgroups: bool = False, group_size_n=32, + output_dtype=tkl.f32, + transpose_output: bool = False, ): """Shared implementation: preshuffle scales only, or scales + B data. @@ -159,6 +161,9 @@ def _get_tagged_mxfp4_gemm_preshuffle_scales_impl( is controlled by the selected address spaces (`a_address_space` and `b_address_space`). + When transpose_output is True, the output memory is [N, M] instead of [M, N], + producing C^T in row-major layout. This makes per-lane MMA accumulator + elements contiguous in the M (fast) dimension of the output. """ M = tkl.sym.M N = tkl.sym.N @@ -179,6 +184,19 @@ def _get_tagged_mxfp4_gemm_preshuffle_scales_impl( constraints += [tkw.WaveConstraint(N, BLOCK_N / wave_shape[1])] constraints += [tkw.HardwareConstraint(threads_per_wave=64, mma_type=mfma_variant)] + # Divisibility assumptions for M, N, K (no effect for static shapes). 
+ constraints += [tkw.Assumption(Eq(M % 32, 0))] + constraints += [tkw.Assumption(Eq(N % 32, 0))] + constraints += [tkw.Assumption(Eq(K % 256, 0))] + + # Include assumption that K is divisible by BLOCK_K to allow gather_to_shared ops to omit masking predicates. + constraints += [tkw.Assumption(Eq(K % BLOCK_K, 0))] + constraints += [tkw.Assumption(Eq(M % BLOCK_M, 0))] + constraints += [tkw.Assumption(Eq(N % BLOCK_N, 0))] + + # K is always large enough for software pipelining. + constraints += [tkw.Assumption(K > BLOCK_K * 6)] + if reorder_workgroups: new_wg0, new_wg1 = _reorder_mxfp4_workgroups( M, N, BLOCK_M, BLOCK_N, GROUP_SIZE_N @@ -243,13 +261,26 @@ def _get_tagged_mxfp4_gemm_preshuffle_scales_impl( outputs={K: k_s, N: n_s}, ) + c_dim_0, c_dim_1 = (N, M) if transpose_output else (M, N) + + if transpose_output: + c_it_m = tkw.IndexMapping.iterator(0) + c_it_n = tkw.IndexMapping.iterator(1) + c_write_mapping = tkw.IndexMapping( + num_iterators=2, + inputs={M: c_it_m, N: c_it_n}, + outputs={N: c_it_n, M: c_it_m}, + ) + else: + c_write_mapping = None + @tkw.wave(constraints) def gemm( a: tkl.Memory[M, K / 2, A_ADDRESS_SPACE, tkl.i8], a_scale: tkl.Memory[M, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8], b: tkl.Memory[N, K / 2, B_ADDRESS_SPACE, tkl.i8], b_scale: tkl.Memory[N, K / 32, GLOBAL_ADDRESS_SPACE, tkl.i8], - c: tkl.Memory[M, N, C_ADDRESS_SPACE, tkl.f32], + c: tkl.Memory[c_dim_0, c_dim_1, C_ADDRESS_SPACE, output_dtype], ): c_reg = tkl.Register[M, N, tkl.f32](0.0) @@ -273,7 +304,13 @@ def repeat( ) return acc - tkw.write(repeat, c) + if output_dtype == tkl.bf16: + repeat = tkw.cast(repeat, tkl.bf16) + + if c_write_mapping is not None: + tkw.write(repeat, c, mapping=c_write_mapping, elements_per_thread=4) + else: + tkw.write(repeat, c) hyperparams = { A_ADDRESS_SPACE: a_address_space, @@ -290,7 +327,7 @@ def repeat( M: shape[0], N: shape[1], K: shape[2], - K_SCALE_SHUFFLED: (((shape[2] // 32) + 7) // 8) * 8, + K_SCALE_SHUFFLED: (((K // 32) + 7) // 8) * 8, } if b_preshuffled: hyperparams[K_PACKED] = K // 2 @@ -348,6 +385,8 @@ def get_tagged_mxfp4_gemm_preshuffle_scales_and_B( mfma_variant: ScaledMMAType = ScaledMMAType.F32_16x16x128_F8F6F4, a_address_space: tkl.AddressSpace = SHARED_ADDRESS_SPACE, b_address_space: tkl.AddressSpace | None = None, + output_dtype=tkl.f32, + transpose_output: bool = False, ): """Return a tagged MXFP4 scaled GEMM kernel with preshuffled B and B_scale. @@ -363,6 +402,7 @@ def get_tagged_mxfp4_gemm_preshuffle_scales_and_B( mfma_variant: Scaled MMA instruction type. a_address_space: Address space for A. b_address_space: Address space for B. + transpose_output: If True, output memory is [N, M] instead of [M, N]. Returns: (kernel_function, WaveCompileOptions) """ @@ -374,6 +414,8 @@ def get_tagged_mxfp4_gemm_preshuffle_scales_and_B( a_address_space, b_address_space, b_preshuffled=True, + output_dtype=output_dtype, + transpose_output=transpose_output, )
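For reference, the magic-number division that floordiv_expr and rem_expr now emit for dynamic divisors can be checked with ordinary Python integers. The sketch below mirrors the same arithmetic (magic = ceil(2^32 / d), q = mulhi(n, magic), one-step correction) and only illustrates the math, not the MLIR emission; it assumes 0 <= n < 2^32 and 2 <= d < 2^31, the range where truncating the magic constant to i32 is lossless:

def magic_div_and_rem(n: int, d: int) -> tuple[int, int]:
    """Pure-Python mirror of _magic_div_and_rem (unsigned 32-bit n, divisor d >= 2)."""
    magic = ((1 << 32) + d - 1) // d   # ceil(2^32 / d); fits in 32 bits for d >= 2
    q = (n * magic) >> 32              # mulhi_u32(n, magic): equals n // d or n // d + 1
    r = (n - q * d) & 0xFFFFFFFF       # wraps to a huge value when q overshoots
    if r >= d:                         # unsigned compare detects the overshoot
        q -= 1
        r = (r + d) & 0xFFFFFFFF
    return q, r

# Spot-check against exact // and %; (n, d) = (2**32 - 2, 3) exercises the correction branch.
for n in (0, 1, 63, 12345, 2**31 - 1, 2**32 - 2, 2**32 - 1):
    for d in (2, 3, 7, 64, 192, 65521):
        assert magic_div_and_rem(n, d) == (n // d, n % d)

The emitted IR performs the same steps with i64/i32 arith ops, and _get_or_create_magic caches the (magic_i32, d_i32) pair so the setup is computed once per divisor value.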