
[TE] Phase 2 of small-seq cross-attn integration: a separate cpp backend and a new jax api #542

Open
VeeraRajasekhar wants to merge 22 commits into dev from veergopu/smallseq-cross-attn-new-backend

Conversation

@VeeraRajasekhar
Contributor

… backend

Refactor the ROCm small-sequence attention path so it is a first-class backend instead of branching from the generic CK fused-attention entry: add NVTE entry points and ROCm implementations under fused_attn_rocm, CMake wiring, and public declarations in fused_attn.h.

Rename the small-seq sources to fused_attn_small_seq.* so filenames match the new API. Extend kernel dispatch to head sizes 128, 256, and 512.
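As a rough illustration of the extended dispatch, a Python sketch (the kernel-name scheme here is hypothetical; the real dispatch is done in C++):

```python
SUPPORTED_HEAD_DIMS = (128, 256, 512)

def select_small_seq_kernel(head_dim):
    # Hypothetical sketch of the extended head-size dispatch: only the
    # exact sizes 128, 256, and 512 are supported; anything else is rejected.
    if head_dim not in SUPPORTED_HEAD_DIMS:
        raise ValueError(f"unsupported head_dim: {head_dim}")
    return f"fused_attn_small_seq_hd{head_dim}"
```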

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

VeeraRajasekhar and others added 11 commits February 24, 2026 19:12
Integrate the CK team's unfused variable-length attention HIP kernels from
varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized
path for cross-attention (Q length 1, KV length 2-16, large batch).

- Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under
  fused_attn_rocm/: declarations and implementation adapted from
  varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output;
  grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over
  max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.

- Add fused_attn_smallseq.cpp to the ROCm fused-attn build in
  transformer_engine/common/CMakeLists.txt.

- In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when
  max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size
  == 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host
  max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q,
  h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2)
  call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence
  count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen,
  output_S shape, workspace size, and small-seq fwd so varlen kernel indexing
  matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen
  kernel expects sequence-level batch).

- In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query
  (workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host
  max_seqlen_kv; on real run call get_runtime_max_seqlen then
  fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for
  get_runtime_max_seqlen, workspace size, and small-seq bwd.

- Reuse the softmax LSE auxiliary buffer for attention weights in the small-seq
  path (forward write, backward read).

- JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and
  kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads,
  q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux
  buffer matches C++ attention-weights convention.

- Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py
  (parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD,
  SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in
  C++.
- tests/jax: CK small-seq tests use fixture to set/restore
  NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing
  cases (2048-2-4, 2-4096-8192); when env set, num_segments_per_seq =
  max_seqlen_q for THD else 2.
- JAX attention.py: THD softmax shape/dtype uses small-seq path only when
  env=1, else original layout
- JAX attention.cpp: Added env guard
- fused_attn_smallseq: Use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for
  fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale).
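Taken together, the routing logic described in these commits can be sketched in Python (names are illustrative; rounding a runtime KV length up to the nearest supported kernel size is an assumption, not confirmed by the PR):

```python
SUPPORTED_KV_SEQLENS = (2, 4, 6, 8, 12, 16)

def use_small_seq_path(is_thd, has_bias, max_seqlen_q, max_seqlen_kv):
    # The branch condition described above: THD layout, no bias,
    # Q length exactly 1, KV length between 2 and 16.
    return (is_thd and not has_bias
            and max_seqlen_q == 1 and 2 <= max_seqlen_kv <= 16)

def kernel_kv_seqlen(max_seqlen_kv):
    # Pick the smallest supported kernel size that covers the runtime KV
    # length (assumption: the dispatch rounds up rather than requiring an
    # exact match).
    for s in SUPPORTED_KV_SEQLENS:
        if max_seqlen_kv <= s:
            return s
    raise ValueError("KV length > 16 is not handled by the small-seq kernels")
```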
… backend

Refactor the ROCm small-sequence attention path so it is a first-class backend
instead of branching from the generic CK fused-attention entry: add NVTE
entry points and ROCm implementations under fused_attn_rocm, CMake wiring,
and public declarations in fused_attn.h.

Rename the small-seq sources to fused_attn_small_seq.* so filenames match the
new API. Extend kernel dispatch to head sizes 128, 256, and 512.
@VeeraRajasekhar VeeraRajasekhar self-assigned this Apr 15, 2026
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn.cpp
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp Outdated
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp Outdated
Comment thread transformer_engine/common/fused_attn_rocm/fused_attn_small_seq.cpp Outdated
Wire the explicit small-sequence path in JAX csrc.

- extensions.h: declare GetSmallSeqAttn{Forward,Backward}WorkspaceSizes,
  SmallSeqAttn{Forward,Backward}FFI handlers, and XLA_FFI_DECLARE_HANDLER_SYMBOL
  exports for XLA registration.

- attention.cpp (USE_ROCM): add PrepareSmallSeqAttnForwardAuxTensors /
  PrepareSmallSeqAttnBackwardAuxTensors to build NVTETensorPack for small-seq
  (softmax slot = attention-weights buffer layout, RNG slot per fused API
  contract); memset ragged output/softmax aux as needed for THD.

- GetSmallSeqAttnForwardWorkspaceSizes / GetSmallSeqAttnBackwardWorkspaceSizes:
  gate on nvte_is_small_seq_attn_supported, return minimal forward workspace and
  nvte_fused_attn_small_seq_bwd_workspace_size-backed backward scratch.

- SmallSeqAttnForwardImpl / SmallSeqAttnBackwardImpl: reuse FUSED_ATTN_IMPL_COMMON_BLOCK
  for THD cu_seqlens/offsets, call nvte_fused_attn_small_seq_fwd / _bwd.

- SmallSeqAttnForwardFFI / SmallSeqAttnBackwardFFI + XLA_FFI_DEFINE_HANDLER_SYMBOL:
  mirror FusedAttn*FFI attribute unpacking so JAX can invoke the dedicated backend.
Add SmallSeqAttnFwdPrimitive / SmallSeqAttnBwdPrimitive in cpp_extensions/attention.py
so JAX compiles and lowers to the dedicated small-seq FFI without
nvte_get_fused_attn_backend or generic fused-attn workspace probing.

- abstract: HIP-only; validate THD_THD_THD layout, no bias/dropout, and supported
  head dims; softmax_aux shape (*batch, heads, q, min(kv,16)) in Q dtype;
  workspace from get_small_seq_attn_{fwd,bwd}_workspace_sizes.

- lowering: ffi_lowering to te_small_seq_attn_{forward,backward}_ffi with the
  same flattened attrs pattern as generic fused attention.

- fused_attn_small_seq_fwd / fused_attn_small_seq_bwd: thin bind helpers;
  export via __all__. register_primitive for both primitives.
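The abstract rule's aux-buffer shape reduces to a one-liner (a sketch of the shape arithmetic only; the actual rule also validates layout and dtypes):

```python
def small_seq_softmax_aux_shape(batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen):
    # Unlike the generic backend's LSE buffer, the small-seq aux buffer holds
    # the attention weights P = softmax(S), clamped to the 16-step KV limit,
    # and uses the Q dtype rather than float32.
    return (*batch_shape, attn_heads, q_max_seqlen, min(kv_max_seqlen, 16))
```

For example, the test configuration with batch 30720 and 8 heads yields `small_seq_softmax_aux_shape((30720,), 8, 1, 8) == (30720, 8, 1, 8)`.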
Expose ROCm small-sequence cross-attention at the JAX layer next to fused_attn.

- Custom primitive _fused_attn_small_seq with forward/backward rules calling
  cpp_extensions fused_attn_small_seq_fwd/bwd (tex.*).

- fused_attn_small_seq(): user entry point taking (q,k,v), bias slot,
  SequenceDescriptor, seed, mask/layout/scaling/dropout/is_training — targets
  the explicit small-seq backend.
@VeeraRajasekhar VeeraRajasekhar marked this pull request as ready for review April 24, 2026 05:38
@VeeraRajasekhar VeeraRajasekhar changed the base branch from veergopu/fused-varlen-ck-smallseq-integration to dev April 24, 2026 05:39
@VeeraRajasekhar VeeraRajasekhar added the ci-level 3 CI test level 3 label Apr 25, 2026
@VeeraRajasekhar VeeraRajasekhar force-pushed the veergopu/smallseq-cross-attn-new-backend branch from 8ace430 to 09ab963 Compare April 25, 2026 02:10
"the F16_arbitrary_seqlen backend."
)

def _setup_thd_segments_small_seq(self, generate_random_segment_ids):
Collaborator

Since our customer already mentioned that their s_q is 1 for each segment, without any padding, and their s_kv <= 16 including padding, we can create our own small-seq input generation to keep it separate from the non-small-seq tests.

const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
Collaborator

Do we even support dropout with this rng_state?

Comment on lines +1095 to +1096
workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype);
NVTE_CHECK(workspace_bytes >= req_bytes, "nvte_fused_attn_small_seq_bwd: workspace too small.");
Collaborator

nit: now that we no longer mix the old CK and new small_seq flows, will we still allocate more workspace than we need?

Collaborator

nit: can we add their repo as a third-party dependency and reference their .h directly? No need to do this now.

Comment on lines +787 to +793
size_t fused_attn_small_seq_bwd_workspace_size(size_t b,
size_t h_q,
size_t max_seqlen_kv,
DType dtype) {
constexpr size_t elt_size = 2u; // BF16 and FP16 are 2 bytes
return b * h_q * 1 * std::min(max_seqlen_kv, size_t(16)) * elt_size;
}
Collaborator

Let's add a comment explaining what the backward workspace is actually for, namely why it is b * h_q * 1 * std::min(max_seqlen_kv, size_t(16)) * elt_size.
Presumably it holds dS or dP?
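For reference, the quoted size works out to one attention-weight row per (sequence, head) pair, which is consistent with the dS/dP guess (a sketch; the 8-head figure in the example is made up):

```python
def small_seq_bwd_workspace_bytes(b, h_q, max_seqlen_kv, elt_size=2):
    # One [1 x min(s_kv, 16)] row per sequence and head; elt_size=2 matches
    # the BF16/FP16-only support, even though the dtype argument of the C++
    # function is currently unused.
    return b * h_q * 1 * min(max_seqlen_kv, 16) * elt_size

# With the test configuration b=30720, s_kv=16 and a hypothetical 8 heads:
small_seq_bwd_workspace_bytes(30720, 8, 16)  # 30720 * 8 * 16 * 2 = 7864320 bytes
```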

del config, result_infos
q_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
Collaborator

Not sure how the sharding rule needs to change here.

# NVTE uses b = cu_seqlens_q.shape[0] - 1 (one packed segment per slot), not
# reduce(batch_shape). E.g. seqpack with max_seqlen_q>1 yields cu length
# batch*segments+1 while Q still has leading logical batch only.
small_seq_workspace_batch = q_seqlen_or_cu_seqlen_aval.shape[0] - 1
Collaborator

I think on the JAX side we don't need to read the actual values of the cu_seqlens.
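Indeed, the quoted code needs only the static length of the cu_seqlens array, not its device values:

```python
def nvte_batch_from_cu_seqlens_len(cu_seqlens_len):
    # cu_seqlens stores b + 1 cumulative offsets (one packed segment per
    # slot), so the NVTE batch size follows from the static shape alone.
    return cu_seqlens_len - 1

# A logical batch of 4 with 2 segments per sequence gives a cu_seqlens
# array of length 4 * 2 + 1 = 9, hence an NVTE batch of 8.
```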

NVTE_CHECK(bias_batch == 0 && bias_heads == 0,
"SmallSeqAttnForwardImpl: bias not supported for small-seq.");

auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
Collaborator

Do we support bias?

PrepareSmallSeqAttnForwardAuxTensors(&aux_output_tensors, input_batch, attn_heads, q_max_seqlen,
kv_max_seqlen, dtype, softmax_aux, rng_state);

auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
Collaborator

If we know it's a dummy and our nvte_fused_attn_small_seq_fwd doesn't actually use it, we can skip passing it.

Collaborator

Same goes for the other unused tensors, for example rng_state.


auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

if (is_ragged) {
Collaborator

Currently we only support is_ragged; later we will support BSHD. Please add a TODO comment. Otherwise, we should move everything that follows into this if branch.

@Micky774 Micky774 changed the title [TE] Phase 2 of Sciforium cross-attn integration: a separate cpp backend and a new jax api [TE] Phase 2 of small-seq cross-attn integration: a separate cpp backend and a new jax api Apr 27, 2026
@Micky774
Contributor

This branch is currently missing commits 7faf099f and 5f592b1a from the original PR.

Contributor

@Micky774 Micky774 left a comment

Currently, if I'm understanding this correctly, the only way to utilize the new backend is via the Python-level fused_attn_small_seq entrypoint -- is this desired? We have no means of automatically dispatching to this backend through the existing API.

Comment on lines +928 to +931
const size_t runtime_s_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, dev_ptr_cu_seqlens_q, nullptr, workspace, stream));
const size_t runtime_s_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, dev_ptr_cu_seqlens_kv, nullptr, workspace, stream));
Contributor

This calls ck_fused_attn::get_runtime_max_seqlen unconditionally and breaks AOTriton-only builds.

Contributor

Currently the code in this file unconditionally pulls in tex.fused_attn_small_seq_{f,b}wd -- we should guard these functions.

outer_primitive = None

@staticmethod
def abstract(
Contributor

Can we also proactively guard against GQA/MQA here?

)

@staticmethod
def impl(
Contributor

_fix_len_take, convert_to_2d, seqlen/offset processing are all copied between FusedAttn*Primitive.impl and SmallSeqAttn*Primitive.impl -- let's try to reuse them.

Comment on lines +380 to +386
void PrepareSmallSeqAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
const size_t attn_heads, const size_t q_max_seqlen,
const size_t kv_max_seqlen, DType dtype, void *softmax_buf,
void *rng_state_buf) {
PrepareSmallSeqAttnForwardAuxTensors(tensor_pack, input_batch, attn_heads, q_max_seqlen,
kv_max_seqlen, dtype, softmax_buf, rng_state_buf);
}
Contributor

This wrapper is trivial, can we forward directly?

b, dev_ptr_cu_seqlens_q, nullptr, workspace, stream));
const size_t runtime_s_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
b, dev_ptr_cu_seqlens_kv, nullptr, workspace, stream));
if (const char *env_ck = std::getenv("NVTE_LOG_CK_CONFIG");
Contributor

At this API level we use NVTE_LOG_FUSED_ATTN_CONFIG

@VeeraRajasekhar
Contributor Author

VeeraRajasekhar commented Apr 27, 2026

This branch is currently missing commits 7faf099f and 5f592b1a from the original PR.

Yes @Micky774, this was intentional. The missing commits are relevant to some corner cases that we decided not to cover.

@VeeraRajasekhar
Contributor Author

Currently, if I'm understanding this correctly, the only way to utilize the new backend is via the Python-level fused_attn_small_seq entrypoint -- is this desired? We have no means of automatically dispatching to this backend through the existing API.

We cannot guarantee the two dispatch conditions, i.e. max_seq_len_q == 1 and 2 <= max_seq_len_kv <= 16, because they are only known at runtime. For example, these two tests do not satisfy the conditions statically, yet the conditions hold at runtime; this is what blocks us from automatically dispatching to this backend. That said, we will throw assertion errors at runtime if the conditions are not met.

@Micky774
Contributor

We cannot guarantee the two dispatch conditions, i.e. max_seq_len_q == 1 and 2 <= max_seq_len_kv <= 16, because they are only known at runtime. For example, these two tests do not satisfy the conditions statically, yet the conditions hold at runtime; this is what blocks us from automatically dispatching to this backend. That said, we will throw assertion errors at runtime if the conditions are not met.

If we're providing a static entry point anyways, I suspect this means that we expect users to know ahead of time whether their data complies with such a backend. In which case, even mediating with an environment variable should be feasible right? I'm just trying to understand the actual user experience we're trying to support here.

@Micky774
Contributor

Alternatively, even just adding max_segment_seqlen_q, max_segment_seqlen_kv as helper descriptors to the SequenceDescriptor spec would suffice to enable an opt-in dispatch that requires user-ownership of data viability.

@wangye805
Collaborator

Alternatively, even just adding max_segment_seqlen_q, max_segment_seqlen_kv as helper descriptors to the SequenceDescriptor spec would suffice to enable an opt-in dispatch that requires user-ownership of data viability.

Previously I discussed this with Veera and also with our customers. There are three things preventing us from using a unified API like the previous CK flow:
1) Due to JAX's sequence packing, the segment length could be quite different from max_seqlen_q/kv.
2) Due to jit, the runtime sequence lengths cannot be obtained during the workspace buffer allocation phase.
3) This set of new kernels is basically fused attention that saves P = softmax(S) of shape [b, h, s_q, s_kv], which differs from the CK flow, which saves log(row sum of exp(S)).

So, do you have any good ideas to work around those?
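The distinction in point 3 can be made concrete in plain Python (a sketch of the math only, not the kernel code):

```python
import math

def softmax_row_and_lse(scores):
    # The small-seq kernels save the whole probability row P = softmax(S)
    # (s_kv values per query), while the CK flow saves only the scalar
    # log-sum-exp log(sum(exp(S))) per row.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    probs = [e / denom for e in exps]  # what the small-seq path stores
    lse = m + math.log(denom)          # what the CK path stores
    return probs, lse
```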
