Add stream-K MXFP4 GEMM kernel and tests#1296
Open
willghatch wants to merge 2 commits intomainfrom
Open
Conversation
This is a functional streamk for mxfp4. Persistent-kernel approach: a fixed number of CTAs iterate over work units, each covering a K-range of an output tile. Uses atomic_add for accumulation (caller must zero-init C). Additionally, fix atomic_add codegen to use write-side index mapping handle_generic_atomic called transform_index_on_mapping with the default is_read=True, so for atomic_add ops with an IndexMapping the *input* side of the mapping was used to compute the destination address. When the mapping carries write-side dynamic offsets (e.g. CTA_M_OFFSET, CTA_N_OFFSET in stream-K kernels) those offsets were silently ignored, causing all CTAs to write to the same tile. Pass is_read=False so the output side of the mapping is used instead. Made-with: Cursor Signed-off-by: William G Hatch <william@hatch.uno>
Contributor
Author
|
@harsh-nod As mentioned, this is one of the splitk/streamk branches that was waiting for the splitk PR. |
The mapping `inputs` field describes the target memory location for atomic operations (consistent with how read mappings use `inputs` to specify the source location). Using `is_read=False` (output mapping) broke `atomic_min` which relies on `inputs` for the reduction target address. Change Details: - handlers.py: Revert `is_read=False` to `is_read=True` so atomic ops use `map_input_indices`, consistent with reads. - tagged_mxfp4_gemm.py: Fix `c_write_mapping` to put the CTA offsets in `inputs` (the target location) and bare iterators in `outputs`, matching the mapping convention used everywhere else. Signed-off-by: William G Hatch <william@hatch.uno>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This is a functional streamk for mxfp4.
Persistent-kernel approach: a fixed number of CTAs iterate over work units, each covering a K-range of an output tile. Uses atomic_add for accumulation (caller must zero-init C).
Additionally, fix atomic_add codegen to use write-side index mapping
handle_generic_atomic called transform_index_on_mapping with the default is_read=True, so for atomic_add ops with an IndexMapping the input side of the mapping was used to compute the destination address. When the mapping carries write-side dynamic offsets (e.g. CTA_M_OFFSET, CTA_N_OFFSET in stream-K kernels) those offsets were silently ignored, causing all CTAs to write to the same tile.
Pass is_read=False so the output side of the mapping is used instead.
Made-with: Cursor