[JAX] [PyT] Tighten SWA checks in DPA, MHA and other APIs before passing onto cuDNN fused attn & unfused attn backends#2970
Conversation
- Verifies if a window is correct for a given mask type. If it isn't either force sentinel values or assert. If forcing sentinel values then warn the user - All possible ways of using attn, i.e. DPA, MHA, TL, fused attn APIs are all now guaranteeing that window size will not be None and appropriately set before passing downstream to internal APIs, primitives or classes. Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…tract responsibility can be handled by MHA and lower APIs Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…fused attn Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
|
/te-ci jax L0 L1 |
Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
2c5a448 to
c770934
Compare
for more information, see https://pre-commit.ci
Greptile SummaryThis PR introduces a
Confidence Score: 5/5Safe to merge — canonicalization is applied consistently at every public entry point and the cpp-extension boundary, double-calls are idempotent, and the PyTorch negative-right-value coerce bug is correctly fixed. The new check_set_window_size branches cover all sentinel combinations correctly, including the previously problematic negative-right-value inputs that are now rejected rather than silently coerced. All module constructors apply the function before the module is frozen, and subsequent calls at lower layers see only already-canonical values, so no spurious warnings will fire in normal usage. No files require special attention. Important Files Changed
Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
…cing Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
…cing for PyTorch framework code Signed-off-by: Kshitij Lakhani <klakhani@nvidia.com>
for more information, see https://pre-commit.ci
|
/te-ci L0 L1 |
Description
TE PyT attn uses
check_set_window_size()to regulate thewindow_sizebased on theattn maskand the user passedwindow_size. This is done higher in the stack, so that a limited subset of "valid" values of thewindow_sizepropagate to the backends.Via this PR TE JAX attn tries to mimic this behavior to clean up the SWA mechanism in the different TE JAX attn APIs via a common updating logic. This new function
check_set_window_size()does not constrain the user of the API (rather strips this complexity off for the user) when using (or not using SWA). TE handles the checks and modifications to thewindow_sizeinternally and warns the user about any canonicalization performed, when needed.Type of change
Changes
Update APIs to canonicalize the SWA before passing to the backends
Update internal CP P2P Ring helpers to re-canonicalize when it changes the mask for internal computations
Update tests to validate the chosen SWA window before running the tests
nit: [PyT] Fixed a small QOL check in the right side window for causal cases so as to not actively coerce that value and instead assert it.
Testing
Ran local fused attn + dist fused attn tests on H1008 and GB2004 and they pass successfully.
Next steps
A subsequent PR will fix any issues / missing op in Unfused DPA SWA to make the SWA infrastructure and fallback paths more robust
Checklist: