[TE] Phase 2 of small-seq cross-attn integration: a separate cpp backend and a new jax api #542
VeeraRajasekhar wants to merge 22 commits into dev
Conversation
Integrate the CK team's unfused variable-length attention HIP kernels from
varlen_attn/ into Transformer Engine's ROCm fused-attn path as a specialized
path for cross-attention (Q length 1, KV length 2-16, large batch).
- Add fused_attn_smallseq.hpp and fused_attn_smallseq.cpp under
fused_attn_rocm/: declarations and implementation adapted from
varlen_attn/attn_fwd.cpp and attn_bwd.cpp (scores, mask+softmax, output;
grad_V, grad_attn, softmax bwd, grad_Q/grad_K). Runtime dispatch over
max_seqlen_kv in {2,4,6,8,12,16}, head_dim 128, BF16.
- Add fused_attn_smallseq.cpp to the ROCm fused-attn build in
transformer_engine/common/CMakeLists.txt.
- In fused_attn_ck_fwd: when THD and no bias, branch to small-seq path when
max_seqlen_q==1 and 2<=max_seqlen_kv<=16. On shape query (Aux_CTX_Tensors->size
== 0) skip get_runtime_max_seqlen (cu_seqlens pointers are null); use host
max_seqlen_kv and set output_S to attention-weights shape {max_tokens_q,
h_q, 1, runtime_max_seqlen_kv} and dtype QKV_type. On real run (size >= 2)
call get_runtime_max_seqlen then fused_attn_smallseq_fwd. Use sequence
count b_varlen = max_tokens_q (not segment count b) for get_runtime_max_seqlen,
output_S shape, workspace size, and small-seq fwd so varlen kernel indexing
matches Q and cu_seqlens_kv (THD may pass segment-level cu_seqlens; varlen
kernel expects sequence-level batch).
- In fused_attn_ck_bwd: same THD/small-seq condition. On workspace query
(workspace->data.dptr == nullptr) skip get_runtime_max_seqlen and use host
max_seqlen_kv; on real run call get_runtime_max_seqlen then
fused_attn_smallseq_bwd. Use b_varlen = max_tokens_q_bwd for
get_runtime_max_seqlen, workspace size, and small-seq bwd.
- Reuse the softmax LSE auxiliary buffer for attention weights in the small-seq
path (forward write, backward read).
- JAX attention.py: in NVTE_CK block, when THD and q_max_seqlen==1 and
kv_max_seqlen<=16 set softmax_shape = (*batch_shape, attn_heads,
q_max_seqlen, kv_max_seqlen) and softmax_dtype = q_dtype so Python aux
buffer matches C++ attention-weights convention.
- Add test_ck_unfused_smallseq_backend in tests/jax/test_fused_attn.py
(parametrized s_kv in {2,4,6,8,12,16}, b=30720, s_q=1, THD_THD_THD,
SeqDescFormat.Seqlens) and optional NVTE_LOG_CK_SMALLSEQ debug logging in
C++.
…port to small-seq kernels
- tests/jax: CK small-seq tests use a fixture to set/restore NVTE_FUSED_ATTN_CK_SMALLSEQ=1; parametrize dtype (BF16/FP16) and add sequence-packing cases (2048-2-4, 2-4096-8192); when the env var is set, num_segments_per_seq = max_seqlen_q for THD, else 2.
- JAX attention.py: THD softmax shape/dtype uses the small-seq path only when env=1, else the original layout.
- JAX attention.cpp: add env guard.
- fused_attn_smallseq: use TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT for fwd/bwd; add FP16 (__half) support; fix __half*float with T(scale).
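The env-gated layout selection described above could look roughly like this in Python. A hypothetical sketch: `thd_softmax_layout` and its parameters are illustrative names, not the actual attention.py code.

```python
import os

def thd_softmax_layout(batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen,
                       q_dtype, default_shape, default_dtype):
    # Use the attention-weights layout only when the env gate is on and the
    # small-seq shape conditions hold; otherwise keep the original layout.
    if (os.environ.get("NVTE_FUSED_ATTN_CK_SMALLSEQ") == "1"
            and q_max_seqlen == 1 and kv_max_seqlen <= 16):
        return (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen), q_dtype
    return default_shape, default_dtype
```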
… backend

Refactor the ROCm small-sequence attention path so it is a first-class backend instead of branching from the generic CK fused-attention entry: add NVTE entry points and ROCm implementations under fused_attn_rocm, CMake wiring, and public declarations in fused_attn.h. Rename the small-seq sources to fused_attn_small_seq.* so filenames match the new API. Extend kernel dispatch to head sizes 128, 256, and 512.
Wire the explicit small-sequence path in JAX csrc.
- extensions.h: declare GetSmallSeqAttn{Forward,Backward}WorkspaceSizes,
SmallSeqAttn{Forward,Backward}FFI handlers, and XLA_FFI_DECLARE_HANDLER_SYMBOL
exports for XLA registration.
- attention.cpp (USE_ROCM): add PrepareSmallSeqAttnForwardAuxTensors /
PrepareSmallSeqAttnBackwardAuxTensors to build NVTETensorPack for small-seq
(softmax slot = attention-weights buffer layout, RNG slot per fused API
contract); memset ragged output/softmax aux as needed for THD.
- GetSmallSeqAttnForwardWorkspaceSizes / GetSmallSeqAttnBackwardWorkspaceSizes:
gate on nvte_is_small_seq_attn_supported, return minimal forward workspace and
nvte_fused_attn_small_seq_bwd_workspace_size-backed backward scratch.
- SmallSeqAttnForwardImpl / SmallSeqAttnBackwardImpl: reuse FUSED_ATTN_IMPL_COMMON_BLOCK
for THD cu_seqlens/offsets, call nvte_fused_attn_small_seq_fwd / _bwd.
- SmallSeqAttnForwardFFI / SmallSeqAttnBackwardFFI + XLA_FFI_DEFINE_HANDLER_SYMBOL:
mirror FusedAttn*FFI attribute unpacking so JAX can invoke the dedicated backend.
Add SmallSeqAttnFwdPrimitive / SmallSeqAttnBwdPrimitive in cpp_extensions/attention.py
so JAX compiles and lowers to the dedicated small-seq FFI without
nvte_get_fused_attn_backend or generic fused-attn workspace probing.
- abstract: HIP-only; validate THD_THD_THD layout, no bias/dropout, and supported head dims; softmax_aux shape (*batch, heads, q, min(kv, 16))
in Q dtype; workspace from get_small_seq_attn_{fwd,bwd}_workspace_sizes.
- lowering: ffi_lowering to te_small_seq_attn_{forward,backward}_ffi with the
same flattened attrs pattern as generic fused attention.
- fused_attn_small_seq_fwd / fused_attn_small_seq_bwd: thin bind helpers;
export via __all__. register_primitive for both primitives.
Expose ROCm small-sequence cross-attention at the JAX layer next to fused_attn.
- Custom primitive _fused_attn_small_seq with forward/backward rules calling cpp_extensions fused_attn_small_seq_fwd/bwd (tex.*).
- fused_attn_small_seq(): user entry point taking (q, k, v), a bias slot, SequenceDescriptor, seed, and mask/layout/scaling/dropout/is_training; targets the explicit small-seq backend.
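As a rough illustration of what this entry point computes per query token, here is a pure-Python reference under the small-seq assumptions (q length 1, at most 16 KV positions). This is not TE code, just the math: softmax over the KV scores, then a weighted sum of V.

```python
import math

def small_seq_cross_attn_ref(q, k, v, scale=1.0):
    """Reference forward for one query token. q: [d]; k, v: [s_kv][d]."""
    # Scaled dot-product scores against each KV position.
    scores = [scale * sum(qi * ki for qi, ki in zip(q, krow)) for krow in k]
    # Numerically stable softmax over at most 16 positions.
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    denom = sum(exps)
    probs = [e / denom for e in exps]  # attention weights, shape [s_kv]
    # Output is the probability-weighted sum of V rows.
    out = [sum(p * vrow[j] for p, vrow in zip(probs, v))
           for j in range(len(v[0]))]
    return out, probs
```

The `probs` vector is exactly what the backend stashes in the reused softmax-aux buffer (attention weights rather than LSE).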
```python
                "the F16_arbitrary_seqlen backend."
            )

    def _setup_thd_segments_small_seq(self, generate_random_segment_ids):
```
Since our customer already mentioned that their s_q is 1 for each segment, without padding at all, and their s_kv <= 16 including padding, we can create our own small-seq input generation to separate it from the non-small-seq tests.
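A generator along those lines might look like this. Hypothetical sketch under the stated customer constraints (every Q segment exactly 1 token, KV segments at most 16); names are illustrative, not the test-suite API.

```python
import random

def make_small_seq_thd_inputs(batch, s_kv_max):
    """Build cu_seqlens for the small-seq regime: Q segments of length
    exactly 1 (no padding), KV segments of length 2..s_kv_max (<= 16)."""
    q_seqlens = [1] * batch
    kv_seqlens = [random.randint(2, s_kv_max) for _ in range(batch)]
    cu_seqlens_q = [0]
    cu_seqlens_kv = [0]
    for lq, lkv in zip(q_seqlens, kv_seqlens):
        cu_seqlens_q.append(cu_seqlens_q[-1] + lq)
        cu_seqlens_kv.append(cu_seqlens_kv[-1] + lkv)
    return cu_seqlens_q, cu_seqlens_kv
```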
```cpp
const Tensor *input_cu_seqlens_kv = convertNVTETensorCheck(cu_seqlens_kv);
const Tensor *input_cu_seqlens_q_padded = convertNVTETensorCheck(cu_seqlens_q_padded);
const Tensor *input_cu_seqlens_kv_padded = convertNVTETensorCheck(cu_seqlens_kv_padded);
const Tensor *input_rng_state = convertNVTETensorCheck(rng_state);
```
Do we even support dropout with this rng_state?
```cpp
workspace_bytes *= fused_attn_rocm::nvte_dtype_size(wkspace->data.dtype);
NVTE_CHECK(workspace_bytes >= req_bytes, "nvte_fused_attn_small_seq_bwd: workspace too small.");
```
nit: now that we don't have the mixed old CK + new small_seq flow, will we still allocate more workspace than we need?
nit: can we have their repo as a 3rd-party dependency and then reference their .h? No need to do this now.
```cpp
size_t fused_attn_small_seq_bwd_workspace_size(size_t b,
                                               size_t h_q,
                                               size_t max_seqlen_kv,
                                               DType dtype) {
  constexpr size_t elt_size = 2u;  // BF16 and FP16 are 2 bytes
  return b * h_q * 1 * std::min(max_seqlen_kv, size_t(16)) * elt_size;
}
```
Let's have a comment on what the actual bwd workspace is for, namely why it's b * h_q * 1 * std::min(max_seqlen_kv, size_t(16)) * elt_size.
Probably it's because it's for dS or dP?
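Echoing that guess (the dS/dP interpretation is the reviewer's, not confirmed by the source), the formula reads as one attention-probability-gradient tile per (sequence, head), clamped to 16 KV positions, in a 2-byte dtype:

```python
def small_seq_bwd_workspace_bytes(b, h_q, max_seqlen_kv, elt_size=2):
    # One [1, min(s_kv, 16)] tile of dP/dS per (sequence, head), stored in a
    # 16-bit dtype (BF16/FP16) -- hence b * h_q * 1 * min(kv, 16) * 2 bytes.
    return b * h_q * 1 * min(max_seqlen_kv, 16) * elt_size
```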
```python
del config, result_infos
q_spec = get_padded_spec(arg_infos[0])
out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec))
softmax_aux_sharding = NamedSharding(
```
Not sure how the sharding rule needs to be changed.
```python
# NVTE uses b = cu_seqlens_q.shape[0] - 1 (one packed segment per slot), not
# reduce(batch_shape). E.g. seqpack with max_seqlen_q > 1 yields cu length
# batch*segments + 1 while Q still has the leading logical batch only.
small_seq_workspace_batch = q_seqlen_or_cu_seqlen_aval.shape[0] - 1
```
I think at the JAX side we don't need to read the actual values of cu_seqlens.
```cpp
NVTE_CHECK(bias_batch == 0 && bias_heads == 0,
           "SmallSeqAttnForwardImpl: bias not supported for small-seq.");

auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);
```
```cpp
PrepareSmallSeqAttnForwardAuxTensors(&aux_output_tensors, input_batch, attn_heads, q_max_seqlen,
                                     kv_max_seqlen, dtype, softmax_aux, rng_state);

auto dummy_page_table_tensor = TensorWrapper(nullptr, std::vector<size_t>{1}, DType::kInt32);
```
If we know it's a dummy and our nvte_fused_attn_small_seq_fwd does not actually use it, we can skip passing it.
Same for other unused tensors, for example rng_state.
```cpp
auto bias_tensor = TensorWrapper(bias, bias_shape, dtype);

if (is_ragged) {
```
Currently we only support is_ragged. Later we will support BSHD. Please add a TODO comment. Otherwise, we should move all of the following code into this if branch.
|
This branch is currently missing commits.
Micky774 left a comment:
Currently, if I'm understanding this correctly, the only way to utilize the new backend is via the Python-level fused_attn_small_seq entrypoint -- is this desired? We have no means of automatically dispatching to this backend through existing API.
```cpp
const size_t runtime_s_q = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
    b, dev_ptr_cu_seqlens_q, nullptr, workspace, stream));
const size_t runtime_s_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
    b, dev_ptr_cu_seqlens_kv, nullptr, workspace, stream));
```
This calls ck_fused_attn::get_runtime_max_seqlen unconditionally and breaks AOTriton-only builds.
Currently the code in this file unconditionally pulls in tex.fused_attn_small_seq_{f,b}wd -- we should guard these funcs.
```python
outer_primitive = None

@staticmethod
def abstract(
```
Can we also proactively guard against GQA/MQA here?
```python
)

@staticmethod
def impl(
```
_fix_len_take, convert_to_2d, seqlen/offset processing are all copied between FusedAttn*Primitive.impl and SmallSeqAttn*Primitive.impl -- let's try to reuse them.
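One way to deduplicate is to lift the shared processing into a module-level helper both impls call. A hypothetical sketch of the simplest piece (per-segment lengths to cumulative offsets); the real shared code would also cover _fix_len_take and convert_to_2d:

```python
def lens_to_cu_seqlens(seqlens):
    # Convert per-segment lengths to cumulative offsets once, for reuse by
    # both FusedAttn*Primitive.impl and SmallSeqAttn*Primitive.impl.
    cu = [0]
    for s in seqlens:
        cu.append(cu[-1] + s)
    return cu
```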
```cpp
void PrepareSmallSeqAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, const size_t input_batch,
                                           const size_t attn_heads, const size_t q_max_seqlen,
                                           const size_t kv_max_seqlen, DType dtype, void *softmax_buf,
                                           void *rng_state_buf) {
  PrepareSmallSeqAttnForwardAuxTensors(tensor_pack, input_batch, attn_heads, q_max_seqlen,
                                       kv_max_seqlen, dtype, softmax_buf, rng_state_buf);
}
```
This wrapper is trivial, can we forward directly?
```cpp
    b, dev_ptr_cu_seqlens_q, nullptr, workspace, stream));
const size_t runtime_s_kv = static_cast<size_t>(ck_fused_attn::get_runtime_max_seqlen(
    b, dev_ptr_cu_seqlens_kv, nullptr, workspace, stream));
if (const char *env_ck = std::getenv("NVTE_LOG_CK_CONFIG");
```
At this API level we use NVTE_LOG_FUSED_ATTN_CONFIG.
Yes @Micky774, this was intentional. Those missing commits are relevant to some of the corner cases which we decided not to cover.
We cannot guarantee the two conditions, i.e., max_seqlen_q == 1 and 2 <= max_seqlen_kv <= 16, ahead of time, since they are only known at runtime. For example, these two tests do not satisfy the conditions statically, yet the conditions may hold at runtime; this is what blocks us from automatically dispatching to this backend. That being said, we will throw assertion errors at runtime if the conditions are not met.
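One plausible shape for that runtime guard (a hedged Python sketch only; the real check would live in the C++ backend, and the exact per-segment KV policy is not pinned down by this thread):

```python
def assert_small_seq_runtime(cu_seqlens_q, cu_seqlens_kv):
    """Verify every Q segment has length 1 and the runtime max KV segment
    length falls in [2, 16]; raise otherwise."""
    q_lens = [b - a for a, b in zip(cu_seqlens_q, cu_seqlens_q[1:])]
    kv_lens = [b - a for a, b in zip(cu_seqlens_kv, cu_seqlens_kv[1:])]
    if any(l != 1 for l in q_lens):
        raise ValueError("small-seq backend requires seqlen_q == 1 per segment")
    if not (2 <= max(kv_lens) <= 16):
        raise ValueError("small-seq backend requires 2 <= max seqlen_kv <= 16")
```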
If we're providing a static entry point anyway, I suspect this means that we expect users to know ahead of time whether their data complies with such a backend. In which case, even mediating with an environment variable should be feasible, right? I'm just trying to understand the actual user experience we're trying to support here.
Alternatively, even just adding
Previously I discussed this with Veera and also our customers. There are two or three things that prevent us from using a unified API as in the previous CK flow. Do you have any good ideas to work around those?