
[JAX] [PyT] Tighten SWA checks in DPA, MHA and other APIs before passing onto cuDNN fused attn & unfused attn backends#2970

Open
KshitijLakhani wants to merge 8 commits into NVIDIA:main from KshitijLakhani:klakhani/feat/attn-swa-enh-jax

Conversation

@KshitijLakhani
Collaborator

@KshitijLakhani KshitijLakhani commented May 8, 2026

Description

TE PyTorch attention uses check_set_window_size() to regulate window_size based on the attention mask and the user-passed window_size. This is done high in the stack so that only a limited subset of "valid" window_size values propagates to the backends.

With this PR, TE JAX attention mimics that behavior, cleaning up the SWA mechanism across the different TE JAX attention APIs via a common canonicalization routine. The new check_set_window_size() does not constrain users of the API (rather, it strips this complexity away for them) whether or not they use SWA: TE handles the checks and modifications to window_size internally and warns the user whenever canonicalization is performed.
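To make the described behavior concrete, here is a minimal sketch of such a canonicalizer. The rules below are assumptions inferred from this PR's description (None maps to a no-SWA sentinel, causal masks force the right window to 0, genuinely invalid values raise); the real TE helper may differ in detail.

```python
import warnings
from typing import Optional, Tuple


def check_set_window_size(
    attn_mask_type: str,
    window_size: Optional[Tuple[int, int]] = None,
    warn: bool = True,
) -> Tuple[int, int]:
    """Return a canonical (left, right) sliding window for the given mask type.

    Assumed rules (illustrative, not TE's actual code):
      * None means "no SWA": (-1, -1) for non-causal masks, (-1, 0) for causal.
      * For causal masks the right window must be 0; the cross-compatible
        sentinel (-1, -1) and positive right values are coerced with a warning.
      * Any other negative value is rejected outright.
    """
    causal = "causal" in attn_mask_type
    if window_size is None:
        return (-1, 0) if causal else (-1, -1)
    left, right = window_size
    if left < -1 or right < -1:
        raise ValueError(f"invalid window_size {window_size} for mask {attn_mask_type!r}")
    if causal and right != 0:
        if right > 0 or (left, right) == (-1, -1):
            if warn:
                warnings.warn(
                    f"window_size {window_size} canonicalized to ({left}, 0) "
                    f"for attn_mask_type {attn_mask_type!r}"
                )
            return (left, 0)
        raise ValueError(f"right window {right} is invalid for causal masks")
    return (left, right)
```

Note that the function is idempotent: feeding an already-canonical value back in returns it unchanged with no warning, which is what lets every layer of the stack call it defensively.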

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Update APIs to canonicalize the SWA before passing to the backends

  • Update internal CP P2P Ring helpers to re-canonicalize when they change the mask for internal computations

  • Update tests to validate the chosen SWA window before running the tests

  • nit: [PyT] Fixed a small QOL check on the right-side window for causal cases so that an invalid value is asserted on rather than actively coerced
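The nit in the last bullet can be illustrated with a toy snippet (the function name and exact rules here are illustrative, not TE's actual code): for causal masks the right window must be 0, so a positive user value is coerced with a warning, while a negative value now trips an assertion instead of being silently rewritten.

```python
def canonicalize_causal_right_window(window_size, warn=True):
    """Toy sketch of the fixed causal right-window behavior (illustrative only)."""
    left, right = window_size
    if right == 0:
        return (left, 0)
    # Fixed check: only *positive* right values are coerced down to 0 with a
    # warning; a negative right value now fails the assertion instead of being
    # silently coerced to 0.
    assert right > 0, f"right window {right} is invalid for causal masks"
    if warn:
        print(f"coercing window_size {window_size} to ({left}, 0) for causal mask")
    return (left, 0)
```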

Testing

Ran local fused-attention and distributed fused-attention tests on H100 and GB200, and they pass successfully.

Next steps

A subsequent PR will fix any issues / missing ops in unfused DPA SWA to make the SWA infrastructure and fallback paths more robust.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

- Verifies whether a window is correct for a given mask type; if it isn't, either force sentinel values or assert. When forcing sentinel values, warn the user.
- All entry points into attention (the DPA, MHA, TransformerLayer, and fused-attention APIs) now guarantee that window_size is not None and is appropriately set before being passed downstream to internal APIs, primitives, or classes.

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…tract responsibility can be handled by MHA and lower APIs

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…fused attn

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani self-assigned this May 8, 2026
@KshitijLakhani KshitijLakhani added enhancement New feature or request jax attention labels May 8, 2026
@KshitijLakhani
Collaborator Author

/te-ci jax L0 L1

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani force-pushed the klakhani/feat/attn-swa-enh-jax branch from 2c5a448 to c770934 on May 8, 2026 18:28
@KshitijLakhani KshitijLakhani marked this pull request as ready for review May 8, 2026 18:33
@greptile-apps
Contributor

greptile-apps Bot commented May 8, 2026

Greptile Summary

This PR introduces a check_set_window_size() canonicalization function in the JAX attention path that mirrors the existing PyTorch helper. It enforces that every call site — module constructors (DotProductAttention, MultiHeadAttention, _FusedDotProductAttention, _UnfusedDotProductAttention) and the cpp-extension boundary (fused_attn_fwd/bwd) — receives a properly encoded (left, right) sentinel before the value is forwarded to cuDNN. A corresponding one-line fix is applied to the PyTorch version to prevent negative right-side values from being silently coerced.

  • New check_set_window_size in JAX: canonicalizes None and cross-compatible sentinel values ((-1, -1) vs. (-1, 0)) with per-mask-type rules, raising on genuinely invalid inputs and emitting a UserWarning for coercible but unexpected values.
  • Ring-CP get_step_config: re-canonicalizes window_size against the per-step attn_mask_type override using warn=False to avoid confusing user-visible warnings.
  • Fast-path sentinel in _segment_ids_pos_to_seqlens_offsets: the old window_size is None / window_size == (-1, -1) checks are unified into window_size[0] == -1.
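The sentinel unification in the last bullet can be sketched as a before/after pair (function names here are illustrative, not TE's actual helpers):

```python
# Before canonicalization was guaranteed, the fast path had to accept two
# spellings of "SWA disabled":
def swa_disabled_old(window_size):
    return window_size is None or window_size == (-1, -1)

# Once every caller routes through the canonicalizer, window_size is always a
# 2-tuple, and a left value of -1 means the left window is unbounded, i.e. SWA
# is effectively off, regardless of the mask-dependent right value:
def swa_disabled_new(window_size):
    return window_size[0] == -1
```

The new check also correctly treats the causal no-SWA sentinel (-1, 0) as "SWA off", which the old equality check against (-1, -1) missed.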

Confidence Score: 5/5

Safe to merge — canonicalization is applied consistently at every public entry point and the cpp-extension boundary, double-calls are idempotent, and the PyTorch negative-right-value coerce bug is correctly fixed.

The new check_set_window_size branches cover all sentinel combinations correctly, including the previously problematic negative-right-value inputs that are now rejected rather than silently coerced. All module constructors apply the function before the module is frozen, and subsequent calls at lower layers see only already-canonical values, so no spurious warnings will fire in normal usage.

No files require special attention.

Important Files Changed

  • transformer_engine/jax/attention.py: Adds the check_set_window_size canonicalizer with correct branch logic for causal/non-causal masks; updates the _segment_ids_pos_to_seqlens_offsets fast-path sentinel; type-annotates several internal helpers. No logic errors found.
  • transformer_engine/jax/cpp_extensions/attention.py: Applies check_set_window_size at the fused_attn_fwd/bwd boundary and in get_step_config (warn=False for the internal ring-CP mask switch); updates the SWA-presence sentinel from != (-1, -1) to [0] != -1. Changes are idempotent on already-canonical input.
  • transformer_engine/jax/flax/transformer.py: Adds post-init canonicalization to _UnfusedDotProductAttention, _FusedDotProductAttention, DotProductAttention, and MultiHeadAttention; TransformerLayer deliberately skips canonicalization and delegates to the inner MHA.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: One-line fix: changes orig_window_size[1] != 0 to > 0 in the causal coerce branch, preventing negative right values from being silently coerced instead of rejected.
  • tests/jax/test_fused_attn.py: _get_swa_window_size_for_test now routes its candidate window through check_set_window_size so tests exercise values consistent with production canonicalization.


KshitijLakhani and others added 3 commits May 8, 2026 14:16
…cing

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…cing for PyTorch framework code

Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
@KshitijLakhani KshitijLakhani changed the title [JAX] Tighten SWA checks in DPA, MHA and other APIs before passing onto cuDNN fused attn & unfused attn backends [JAX] [PyT] Tighten SWA checks in DPA, MHA and other APIs before passing onto cuDNN fused attn & unfused attn backends May 8, 2026
@KshitijLakhani
Collaborator Author

/te-ci L0 L1

@KshitijLakhani KshitijLakhani requested a review from cyanguwa May 8, 2026 22:00