Hoist loop-variant affine index ops out of scf.for during codegen #1283

Open
willghatch wants to merge 6 commits into main from users/willghatch/linearize-loop-affine-apply

@willghatch
Contributor

When dimensions are dynamic, the linearized index expressions passed to affine.apply contain the loop induction variable, preventing LICM from moving them out of the loop body. Split each such expression into base + iv * stride, generate the base before the loop with gen_sympy_index_hoisted, and emit the IV-dependent part as plain arith ops that downstream LICM can reason about independently.

The key test is lit_tests/kernel/wave/linearize_loop_affine_maps.py, which checks that no affine.apply ops occur inside a loop body.

Key changes:

  • Add gen_sympy_index_hoisted (emitter.py): decomposes a sympy expression into a loop-invariant base (hoisted) and an IV-scaled offset (emitted in-place), with reconstruction guards that fall back to the original gen_sympy_index when the split is not provably safe.

  • Refactor read_write.py: extract _get_enclosing_scf_for (walks parent regions), _hoist_before_loop, _iv_context, _gen_linear_index_offset, _hoist_dst_indices, and _linearize_memref_hoisted to replace ad-hoc hoisting scattered across _handle_read_linear_index and handle_gather_to_lds.

  • Rework _build_mask to lower each bound condition individually with hoisting support, replacing the previous functools.reduce over a single sympy And expression.

  • Add linearize_loop_affine_maps.py lit test verifying no affine.apply ops remain inside the pipelined scf.for body for two wave shapes.

  • Update CHECK patterns in 5 existing lit tests to reflect the new code shape (fewer affine_map parameters, hoisted base ops before the loop, arith.muli/addi inside the loop).
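
For readers who want a concrete picture of the base + iv * stride split, here is a minimal sketch using sympy only. It is not the PR's gen_sympy_index_hoisted: the function name split_base_and_stride and the symbols row, tid, and K are hypothetical, and the sketch returns None where the real code is described as falling back to gen_sympy_index.

```python
import sympy


def split_base_and_stride(expr, iv):
    """Try to write expr as base + iv * stride with both parts free of iv.

    Returns (base, stride) on success, or None when the reconstruction
    guard cannot prove the split equals the original expression.
    """
    # Loop-invariant part: the value of the expression at iv == 0.
    base = expr.subs(iv, 0)
    # Candidate stride: the change in the expression from iv == 0 to iv == 1.
    stride = sympy.expand(expr.subs(iv, 1) - base)
    # Reconstruction guard: accept only if base + iv * stride expands back
    # to exactly the original expression. Anything else is rejected.
    if sympy.expand(base + iv * stride - expr) != 0:
        return None
    return base, stride


iv, row, tid, K = sympy.symbols("iv row tid K", integer=True, nonnegative=True)

# A linearized index with a dynamic dimension K: affine in iv, stride 16*K.
affine_index = (row + 16 * iv) * K + tid
affine_split = split_base_and_stride(affine_index, iv)

# floor(iv / 2) is not affine in iv, so the guard rejects the split.
non_affine_index = sympy.floor(iv / 2) * K + tid
non_affine_split = split_base_and_stride(non_affine_index, iv)
```

In this sketch the base (K*row + tid) is what would be generated before the loop, while iv * stride is the part left in the loop body as plain multiply and add operations.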

Made-with: Cursor

@willghatch willghatch requested a review from panditsa April 8, 2026 21:24
Add iv_offset caching to deduplicate iv*stride products across reads
that share the same stride coefficient within a loop body.  Add
gen_sympy_index_no_affine for computing expressions in-place using
arith ops (no affine.apply) inside loop bodies.  Use no-affine mode
for the guard condition in _compute_branchless_valid_bytes.  Also add
use_affine parameter to gen_sympy_index for local override of the
global affine expression mode.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
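
The iv*stride deduplication described in this commit can be pictured as a small cache keyed by the induction variable and the stride value. The sketch below is illustrative only: the class IvOffsetCache, its method names, and the use of strings in place of real MLIR values are all assumptions, not the PR's implementation.

```python
class IvOffsetCache:
    """Remember one iv * stride product per (iv, stride) pair in a loop body."""

    def __init__(self):
        self._products = {}   # (iv, stride) -> name of the cached product
        self.emitted = []     # textual stand-in for ops emitted into the body

    def get(self, iv, stride):
        key = (iv, stride)
        if key not in self._products:
            # First request for this stride: emit the multiply once.
            name = f"%iv_off{len(self._products)}"
            self.emitted.append(f"{name} = arith.muli {iv}, {stride} : index")
            self._products[key] = name
        # Later reads with the same stride reuse the existing product.
        return self._products[key]


cache = IvOffsetCache()
first = cache.get("%iv", "%stride_a")
second = cache.get("%iv", "%stride_a")   # same stride: reused, nothing emitted
third = cache.get("%iv", "%stride_b")    # different stride: one more multiply
```

Two reads that share a stride coefficient end up with a single multiply in the loop body, and each read then needs only one add on top of its own hoisted base.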
Contributor

@ftynse ftynse left a comment

Arguably, this should just be an "advanced LICM" MLIR pass rather than extra ad hoc logic in whatever emits MLIR. It is also unclear to me what value affine.apply provides to start with; sympy should already have simplified the expressions sufficiently.

@willghatch
Contributor Author

This approach was mostly taken as an expedient path to the mxfp4 assembly output that we wanted, working around limitations of our LICM pass. I agree that it is not the optimal way to do this.

_verify_stride_on_original was rejecting valid stride decompositions
for B-data and B-scale reads because it probed with arbitrary values
that violated divisibility assumptions (e.g. K not a multiple of 256).
This made floor/Mod terms produce different results per IV step,
causing the verifier to think the stride was non-constant.

Pass div_fwd substitutions into the verifier so probed values always
satisfy the constraints.  This eliminates all residual affine.apply
ops for B-data and B-scale index offsets in the loop body (12 per
test for 1x4, 28 for 2x2).

The only remaining affine.apply in the loop body is a single
bounds-check computation unrelated to data/scale index offsets.

Made-with: Cursor
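
The failure mode fixed here can be reproduced with sympy alone. The sketch below is a hypothetical stand-in for _verify_stride_on_original: the function stride_is_constant, the symbol K1, and the shape of the div_fwd mapping (K replaced by 256 * K1) are assumptions made for illustration.

```python
import sympy


def stride_is_constant(expr, iv, probe, steps=4):
    """Probe expr numerically: is expr(iv + 1) - expr(iv) the same at each step?"""
    probed = expr.subs(probe)
    deltas = {probed.subs(iv, i + 1) - probed.subs(iv, i) for i in range(steps)}
    return len(deltas) == 1


iv, K, K1 = sympy.symbols("iv K K1", integer=True, nonnegative=True)
expr = sympy.floor(iv * K / 256)

# Arbitrary probe value: K = 100 is not a multiple of 256, so the floor term
# steps unevenly and the stride looks non-constant.
unconstrained = stride_is_constant(expr, iv, {K: 100})

# Apply the divisibility substitution first (assumed shape: K = 256 * K1),
# so every probed value satisfies the constraint.
div_fwd = {K: 256 * K1}
constrained = stride_is_constant(expr.subs(div_fwd), iv, {K1: 3})
```

With K = 100 the per-step deltas are 0, 0, 1, 0, so the check fails. After the substitution the expression reduces to iv * K1 and every delta is 3.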
@willghatch willghatch force-pushed the users/willghatch/linearize-loop-affine-apply branch from bd8d704 to eb5b1f8 Compare April 9, 2026 20:00
The affine-apply hoisting optimization reduces register counts with a
locally-built LLVM, but CI's pinned LLVM version produces different
register allocation.  The water backend waitcounts changed as expected,
but register counts remain at main-branch levels (172 VGPRs, 61 SGPRs)
in CI.

Change Details:

- Revert water backend vgpr_count from 164 back to 172 and sgpr_count from 57 back to 61 to match CI-observed values

Made-with: Cursor
Change Details:

- Skip _build_mask and arith_d.select on the gather_to_lds LINEAR_INDEX and N-D paths when valid_bytes_override is set from g2s_guard, so that hardware num_records clamping handles out-of-bounds accesses and VGPR pressure drops as a result.
- Align no_masked_load_store_ops with ogsplit for buffer ops + asm backend so masked loads are allowed and lowering can use hardware OOB where appropriate.

Made-with: Cursor
@github-actions

Water Code Coverage

Filename                                                           Functions  Missed Functions  Executed       Lines      Missed Lines     Cover    Branches   Missed Branches     Cover
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
lib/Transforms/MemrefDecomposition.cpp                                    28                 0   100.00%         600                49    91.83%         104                46    55.77%
lib/Transforms/AllocToAlloca.cpp                                           2                 0   100.00%          17                 0   100.00%           0                 0         -
lib/Transforms/CheckStaticAssertions.cpp                                   2                 0   100.00%          22                 1    95.45%           8                 4    50.00%
lib/Transforms/GPUModuleToBinary.cpp                                      19                 5    73.68%         339               115    66.08%         128                57    55.47%
lib/Transforms/DropTransformOps.cpp                                        2                 0   100.00%          16                 0   100.00%           2                 0   100.00%
lib/Transforms/GPUToGPURuntime.cpp                                        14                 0   100.00%         298                23    92.28%          40                17    57.50%
lib/Transforms/SLPVectorizer.cpp                                          61                 3    95.08%        1065                95    91.08%         558               164    70.61%
lib/Transforms/AccessCheckers.cpp                                         35                 1    97.14%         446                40    91.03%         124                30    75.81%
lib/Transforms/AssembleISA.cpp                                             4                 1    75.00%          30                 2    93.33%           2                 1    50.00%
lib/Dialect/Wave/Transforms/LoweringPatterns.cpp                          48                 2    95.83%         963               147    84.74%         272                82    69.85%
lib/Dialect/Wave/Transforms/PropagateDefaultsFromConstraints.cpp           3                 3     0.00%          35                35     0.00%          12                12     0.00%
lib/Dialect/Wave/Transforms/TypeConverter.cpp                              7                 2    71.43%          96                26    72.92%          32                17    46.88%
lib/Dialect/Wave/Transforms/LowerReadWriteOps.cpp                         10                 0   100.00%         219                15    93.15%          54                10    81.48%
lib/Dialect/Wave/Transforms/DetectNormalForms.cpp                          4                 0   100.00%          48                 0   100.00%           8                 0   100.00%
lib/Dialect/Wave/Transforms/ExpandVariadicReductions.cpp                   2                 0   100.00%          24                 1    95.83%           6                 1    83.33%
lib/Dialect/Wave/Transforms/InferTypes.cpp                                62                11    82.26%         920                94    89.78%         362               153    57.73%
lib/Dialect/Wave/Transforms/LowerWaveToMLIR.cpp                            5                 0   100.00%         129                 1    99.22%          16                 2    87.50%
lib/Dialect/Wave/Transforms/InferIndexExprs.cpp                            3                 0   100.00%          34                 1    97.06%           8                 1    87.50%
lib/Dialect/Wave/Transforms/Utils.cpp                                      4                 0   100.00%          64                 5    92.19%          20                 4    80.00%
lib/Dialect/Wave/Transforms/ResolveDistributedAllocations.cpp              4                 0   100.00%          97                13    86.60%          22                10    54.55%
lib/Dialect/Wave/IR/IndexExprInference.cpp                               173                17    90.17%        3269               286    91.25%        1368               490    64.18%
lib/Dialect/Wave/IR/WaveOps.cpp                                           91                10    89.01%        1596               194    87.84%         712               138    80.62%
lib/Dialect/Wave/IR/WaveAttrs.cpp                                         77                 5    93.51%         985                84    91.47%         434                64    85.25%
lib/Dialect/Wave/IR/IndexExpr.cpp                                         10                 0   100.00%         117                 1    99.15%          24                 3    87.50%
lib/Dialect/Wave/IR/WaveDialect.cpp                                       15                 0   100.00%         498                 9    98.19%         174                 5    97.13%
lib/Dialect/Wave/IR/WaveTypes.cpp                                          9                 1    88.89%          75                 8    89.33%          18                 3    83.33%
lib/Dialect/Wave/IR/WaveInterfaces.cpp                                    37                 0   100.00%         661                42    93.65%         322                44    86.34%
lib/Dialect/Wave/IR/WaveUtils.cpp                                         23                 0   100.00%         217                 8    96.31%          84                13    84.52%
lib/Dialect/NormalForm/Transforms/LowerNormalFormModule.cpp                3                 0   100.00%          34                 6    82.35%           8                 2    75.00%
lib/Dialect/NormalForm/IR/NormalFormDialect.cpp                            1                 0   100.00%           6                 0   100.00%           0                 0         -
lib/Dialect/NormalForm/IR/NormalFormOps.cpp                               12                 0   100.00%         201                 9    95.52%          58                 7    87.93%
lib/Pipelines/Pipelines.cpp                                                2                 0   100.00%          27                 0   100.00%           0                 0         -
lib/Analysis/InUseForSpeculation.cpp                                      12                 1    91.67%         142                 8    94.37%          32                 4    87.50%
include/water/Dialect/Wave/Transforms/LoweringPatterns.h                   1                 0   100.00%           3                 0   100.00%           0                 0         -
include/water/Dialect/Wave/IR/IndexExpr.h                                  1                 0   100.00%          10                 0   100.00%           2                 0   100.00%
include/water/Dialect/Wave/IR/WaveInterfaces.h                            40                 3    92.50%         159                 8    94.97%           8                 2    75.00%
include/water/Dialect/Wave/IR/WaveTypes.h                                  1                 0   100.00%           5                 0   100.00%           4                 0   100.00%
include/water/Dialect/Wave/IR/WaveUtils.h                                  1                 0   100.00%           5                 0   100.00%           4                 1    75.00%
include/water/Dialect/Wave/IR/WaveAttrs.h                                  4                 0   100.00%          16                 0   100.00%           0                 0         -
include/water/Dialect/NormalForm/IR/NormalFormInterfaces.h                 1                 1     0.00%           4                 4     0.00%           0                 0         -
include/water/Analysis/InUseForSpeculation.h                              12                 3    75.00%          39                17    56.41%          16                10    37.50%
-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
TOTAL                                                                    845                69    91.83%       13531              1347    90.05%        5046              1397    72.31%

@harsh-nod harsh-nod self-requested a review April 16, 2026 23:04