fix(pass): Skip R-seq InOut promotion for disjoint variable-offset stores#1232
fix(pass): Skip R-seq InOut promotion for disjoint variable-offset stores#1232Crystal-wzy wants to merge 1 commit intohw-native-sys:mainfrom
Conversation
📝 WalkthroughWalkthroughThis PR extends the call direction derivation pass with expression-variable dependency analysis to detect when callees write to Changes
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~22 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request introduces an optimization to avoid promoting OutputExisting to InOut in sequential loops when writes are disjoint, using a new CalleeHasOnlyVariableOffsetStores check and a caching mechanism. The review feedback points out that the analysis is currently insufficient as it uses a shallow traversal that misses nested stores and fails to verify that offset arguments are variant within the loop. Additionally, the ExprReferencesAnyOf helper should be extended to support more IR node types, and further testing is recommended to ensure the soundness of the disjointness logic.
| auto stmts = transform_utils::FlattenToStmts(callee->body_); | ||
| for (const auto& stmt : stmts) { | ||
| auto assign = As<AssignStmt>(stmt); | ||
| if (!assign) continue; | ||
| auto call = As<Call>(assign->value_); | ||
| if (!call || !call->op_) continue; | ||
|
|
||
| if (call->op_->name_ == "tile.store" && call->args_.size() >= 3) { | ||
| auto target_var = AsVarLike(call->args_[2]); | ||
| if (target_var && aliases.count(target_var.get())) { | ||
| found_store = true; | ||
| aliases.insert(assign->var_.get()); | ||
| if (!ExprReferencesAnyOf(call->args_[1], other_params)) { | ||
| all_variable = false; | ||
| } | ||
| } | ||
| } | ||
| if (call->op_->name_ == "tensor.assemble" && !call->args_.empty()) { | ||
| auto target_var = AsVarLike(call->args_[0]); | ||
| if (target_var && aliases.count(target_var.get())) { | ||
| aliases.insert(assign->var_.get()); | ||
| } | ||
| } | ||
| } |
There was a problem hiding this comment.
This analysis uses FlattenToStmts, which only inspects top-level statements. This is insufficient for resolving call-site parameter directions as nested stores or assignments will be missed. Since effective directions for Spmd and Group types must be computed by inspecting inner kernel calls for accurate dependency tracking, a full IRVisitor should be used instead of a shallow traversal.
References
- When resolving call-site parameter directions, compute effective directions for both Spmd and Group function types by inspecting their inner kernel calls. These function types act as wrappers and their formal parameter directions may not reflect the true data flow, which is crucial for dependency tracking.
| if (cache_it != fn_cache.end()) { | ||
| disjoint = cache_it->second; | ||
| } else { | ||
| disjoint = CalleeHasOnlyVariableOffsetStores(callee, i); |
There was a problem hiding this comment.
The CalleeHasOnlyVariableOffsetStores check only ensures that the callee's store offsets depend on its parameters. This is insufficient to guarantee disjoint writes at the call site. If the caller passes a loop-invariant value to the offset parameter, the writes will overlap across iterations. To safely skip InOut promotion, the pass should also verify that the argument passed to the offset parameter is variant with respect to the sequential loop.
| bool ExprReferencesAnyOf(const ExprPtr& expr, const std::unordered_set<const Var*>& vars) { | ||
| if (!expr) return false; | ||
| if (auto var = As<Var>(expr)) return vars.count(var.get()) > 0; | ||
| if (auto tuple = As<MakeTuple>(expr)) { | ||
| for (const auto& e : tuple->elements_) { | ||
| if (ExprReferencesAnyOf(e, vars)) return true; | ||
| } | ||
| } | ||
| return false; | ||
| } |
There was a problem hiding this comment.
The current implementation of ExprReferencesAnyOf only handles Var and MakeTuple nodes. To handle common offset expressions (e.g., offset + 1), it should be extended to handle BinaryExpr, UnaryExpr, and TupleGetItemExpr, ensuring consistent traversal across IR node types as required by repository standards for collection utilities.
bool ExprReferencesAnyOf(const ExprPtr& expr, const std::unordered_set<const Var*>& vars) {
if (!expr) return false;
if (auto var = As<Var>(expr)) return vars.count(var.get()) > 0;
if (auto tuple = As<MakeTuple>(expr)) {
for (const auto& e : tuple->elements_) {
if (ExprReferencesAnyOf(e, vars)) return true;
}
}
if (auto bin = As<BinaryExpr>(expr)) {
return ExprReferencesAnyOf(bin->left_, vars) || ExprReferencesAnyOf(bin->right_, vars);
}
if (auto un = As<UnaryExpr>(expr)) {
return ExprReferencesAnyOf(un->operand_, vars);
}
if (auto tgi = As<TupleGetItemExpr>(expr)) {
return ExprReferencesAnyOf(tgi->tuple_, vars);
}
return false;
}References
- When adding support for a new IR node type to a transformation pass, ensure all relevant traversal and collection utilities within that pass are updated to handle the new type consistently.
| out = passes.derive_call_directions()(Prog) | ||
| calls = [c for c in _user_calls(out) if c.op.name == "kernel"] | ||
| assert len(calls) == 1 | ||
| assert _dirs(calls[0]) == [ir.ArgDirection.Input, ir.ArgDirection.Scalar, ir.ArgDirection.OutputExisting] |
There was a problem hiding this comment.
Consider adding a test case where a loop-invariant value is passed to the offset parameter. This would demonstrate the soundness issue regarding disjointness at the call site. Per repository guidelines, when testing a pass with a known bug, the expected IR should match the actual buggy output while the issue is tracked.
References
- When testing a pass with a known bug, the 'Expected' IR should match the actual (buggy) output of the pass. The bug itself should be tracked in a separate issue.
be65189 to
4f7be22
Compare
…ores ## Summary - Add `DisjointStoreVisitor` to detect when a callee writes to an Out parameter exclusively via `tile.store` with offsets that depend on other function parameters (position-dependent, disjoint writes) - At the call site, verify the corresponding arguments are loop-variant (reference a sequential loop induction variable) before keeping the original `OutputExisting` direction instead of promoting to `InOut` - Track sequential loop variables in `CallDirectionMutator` and cache per-callee analysis results in `offset_param_cache_` to avoid redundant visitor traversals - Add `ExprReferencesAnyOf` utility for checking transitive Var references across BinaryExpr, UnaryExpr, TupleGetItemExpr, Call, and MakeTuple ## Testing - [x] Added `test_out_param_variable_offset_store_in_seq_loop_not_promoted` verifying disjoint stores keep `OutputExisting` - [x] Added `test_out_param_invariant_offset_in_seq_loop_promoted` verifying loop-invariant offsets still promote to `InOut`
Summary
DisjointStoreVisitorto detect when a callee writes to an Outparameter exclusively via
tile.storewith offsets that depend on otherfunction parameters (position-dependent, disjoint writes)
(reference a sequential loop induction variable) before keeping the
original
OutputExistingdirection instead of promoting toInOutCallDirectionMutatorand cacheper-callee analysis results in
offset_param_cache_to avoid redundantvisitor traversals
ExprReferencesAnyOfutility for checking transitive Var referencesacross BinaryExpr, UnaryExpr, TupleGetItemExpr, Call, and MakeTuple
Testing
test_out_param_variable_offset_store_in_seq_loop_not_promotedverifying disjoint stores keep
OutputExistingtest_out_param_invariant_offset_in_seq_loop_promotedverifying loop-invariant offsets still promote to
InOut