Skip to content

Add stream-K MXFP4 GEMM kernel and tests#1296

Open
willghatch wants to merge 2 commits intomainfrom
users/willghatch/streamk-mxfp4
Open

Add stream-K MXFP4 GEMM kernel and tests#1296
willghatch wants to merge 2 commits intomainfrom
users/willghatch/streamk-mxfp4

Conversation

@willghatch
Copy link
Copy Markdown
Contributor

This is a functional streamk for mxfp4.

Persistent-kernel approach: a fixed number of CTAs iterate over work units, each covering a K-range of an output tile. Uses atomic_add for accumulation (caller must zero-init C).

Additionally, fix atomic_add codegen to use write-side index mapping

handle_generic_atomic called transform_index_on_mapping with the default is_read=True, so for atomic_add ops with an IndexMapping the input side of the mapping was used to compute the destination address. When the mapping carries write-side dynamic offsets (e.g. CTA_M_OFFSET, CTA_N_OFFSET in stream-K kernels) those offsets were silently ignored, causing all CTAs to write to the same tile.

Pass is_read=False so the output side of the mapping is used instead.

Made-with: Cursor

This is a functional streamk for mxfp4.

Persistent-kernel approach: a fixed number of CTAs iterate over work
units, each covering a K-range of an output tile.  Uses atomic_add for
accumulation (caller must zero-init C).

Additionally, fix atomic_add codegen to use write-side index mapping

handle_generic_atomic called transform_index_on_mapping with the
default is_read=True, so for atomic_add ops with an IndexMapping
the *input* side of the mapping was used to compute the destination
address.  When the mapping carries write-side dynamic offsets (e.g.
CTA_M_OFFSET, CTA_N_OFFSET in stream-K kernels) those offsets were
silently ignored, causing all CTAs to write to the same tile.

Pass is_read=False so the output side of the mapping is used instead.

Made-with: Cursor
Signed-off-by: William G Hatch <william@hatch.uno>
@willghatch
Copy link
Copy Markdown
Contributor Author

@harsh-nod As mentioned, this is one of the splitk/streamk branches that was waiting for the splitk PR.

The mapping `inputs` field describes the target memory location for atomic
operations (consistent with how read mappings use `inputs` to specify the
source location).  Using `is_read=False` (output mapping) broke `atomic_min`
which relies on `inputs` for the reduction target address.

Change Details:

- handlers.py: Revert `is_read=False` to `is_read=True` so atomic ops use
  `map_input_indices`, consistent with reads.
- tagged_mxfp4_gemm.py: Fix `c_write_mapping` to put the CTA offsets in
  `inputs` (the target location) and bare iterators in `outputs`, matching
  the mapping convention used everywhere else.

Signed-off-by: William G Hatch <william@hatch.uno>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant