Skip to content

[ROCm] add the bias all row -inf support for jax unfused-attn#556

Open
wangye805 wants to merge 4 commits intodevfrom
yewang12/bias_all_row_inf_check
Open

[ROCm] add the bias all row -inf support for jax unfused-attn#556
wangye805 wants to merge 4 commits intodevfrom
yewang12/bias_all_row_inf_check

Conversation

@wangye805
Copy link
Copy Markdown
Collaborator

Description

CK fixed the corner case with all row -inf in bias in ROCm/composable_kernel#3326. This also aligns with NV upstream. We added the corresponding jax pytest by modifying the jax unfused-attn.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

This concludes the https://github.com/ROCm/frameworks-internal/issues/14668 and https://ontrack-internal.amd.com/browse/SWDEV-561757

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Comment thread tests/jax/test_fused_attn.py
Comment thread tests/jax/test_fused_attn.py Outdated
Comment thread tests/jax/test_fused_attn.py Outdated
Comment thread tests/jax/test_fused_attn.py Outdated
@Micky774
Copy link
Copy Markdown
Contributor

Looks like CI is failing for incorrect broadcasting shapes

@wangye805 wangye805 requested a review from Micky774 April 21, 2026 21:05
),
],
)
def test_backward_bias_all_neg_inf(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be limited to CK backend then

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Added a ck backend guard

Comment thread tests/jax/test_fused_attn.py
@Micky774
Copy link
Copy Markdown
Contributor

CI failure is unrelated and we have a pass on mi35x

Comment thread tests/jax/test_fused_attn.py
Copy link
Copy Markdown
Contributor

@Micky774 Micky774 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Left an optional suggestion, otherwise LGTM.

Co-authored-by: Meekail Zain <34613774+Micky774@users.noreply.github.com>
@wangye805 wangye805 requested a review from ipanfilo April 28, 2026 16:01
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

ci-level 3 CI test level 3

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants