[PyTorch] Enable head dim 256 for FA4 #2932
Conversation
Force-pushed from bdcc02e to 3b3f7d0
Greptile Summary

This PR enables head_dim=256 support in FlashAttention 4 by delegating head-dimension validation to FA4's own `_validate_head_dims` function.
Confidence Score: 4/5

Safe to merge once the unguarded import in transformer_engine/pytorch/attention/dot_product_attention/backends.py is addressed: an ImportError raised for a missing FA4 symbol would break the module at load time.

Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[backends.py module load] --> B{FA4 package installed?}
    B -- No --> C[flash_attn_func_v4 = None]
    B -- Yes --> D[import flash_attn_func, flash_attn_varlen_func,\n_validate_head_dims]
    D -- ImportError if symbol missing --> E[Unhandled ImportError breaks backends.py load]
    D -- OK --> F[v4_validate_head_dims = _fa4_validate_head_dims]
    F --> G[get_attention_backend called]
    G --> H{use_flash_attention_4 and v4_validate_head_dims != None?}
    H -- No --> I[Skip FA4 head-dim validation]
    H -- Yes --> J[Call v4_validate_head_dims]
    J -- AssertionError --> K[use_flash_attention_4 = False]
    J -- OK --> L{SM100 MLA workaround needed?}
    L -- Yes, misaligned --> M[use_flash_attention_4 = False]
    L -- No --> N[FA4 selected]
```
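The guarded import and delegated validation in the flowchart could be sketched as follows. This is a sketch based on the summary above, not the exact PR code: the FA4 module path, `select_fa4`, and its parameters are assumptions.

```python
# Sketch of the guarded-import pattern: fall back to None instead of letting
# an ImportError break backends.py at module load. The module path below is
# a hypothetical placeholder for wherever FA4 exposes these symbols.
try:
    from flash_attn.cute.interface import (
        flash_attn_func as flash_attn_func_v4,
        _validate_head_dims as _fa4_validate_head_dims,
    )
    v4_validate_head_dims = _fa4_validate_head_dims
except ImportError:
    # FA4 not installed, or an expected symbol is missing.
    flash_attn_func_v4 = None
    v4_validate_head_dims = None


def select_fa4(use_flash_attention_4: bool, head_dim_qk: int, head_dim_v: int) -> bool:
    """Delegate head-dim validation to FA4; disable FA4 if it rejects the dims."""
    if use_flash_attention_4 and v4_validate_head_dims is not None:
        try:
            v4_validate_head_dims(head_dim_qk, head_dim_v)
        except AssertionError:
            use_flash_attention_4 = False
    return use_flash_attention_4
```

Delegating to FA4's own validator means Transformer Engine no longer hard-codes the supported head-dim list and automatically tracks future FA4 changes.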
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch L3
@vcherepanov-nv @KshitijLakhani Please review.
```python
@pytest.mark.skipif(
    not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(
    get_device_compute_capability() != (10, 0),
```
Is it supported on b300 (sm103)?
```python
@pytest.mark.skipif(
    not FlashAttentionUtils.v4_is_installed, reason="Flash-attn v4 (flash-attn-4) is required."
)
@pytest.mark.skipif(get_cudnn_version() < (8, 9, 1), reason="cuDNN 8.9.1+ is required.")
```
Should we bump this instead of removing?
```python
# dV TMEM load atoms. When (tile_hdimv // 2) % dK_reduce_ncol != 0, dV reads are
# misaligned. The dedicated (256, 256) kernel uses its own tmem layout so it's
# not affected. See: flash_attn/cute/flash_bwd_sm100.py, line ~262 and ~3890.
if (
```
Should this still be checked when FlashAttentionUtils.v4_validate_head_dims == None?
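The alignment condition that comment describes can be illustrated with a small standalone check. This is a sketch for intuition only; the helper name and the example values below are illustrative assumptions, not values from the kernel.

```python
def dv_tmem_reads_misaligned(tile_hdimv: int, dK_reduce_ncol: int) -> bool:
    """True when (tile_hdimv // 2) is not a multiple of dK_reduce_ncol,
    i.e. when dV TMEM load atoms would be misaligned for this tiling."""
    return (tile_hdimv // 2) % dK_reduce_ncol != 0


# Hypothetical tilings: 192 // 2 = 96 is not a multiple of 64 (misaligned),
# while 256 // 2 = 128 is (aligned).
print(dv_tmem_reads_misaligned(192, 64))  # True
print(dv_tmem_reads_misaligned(256, 64))  # False
```

When the check fires, the PR falls back from FA4 rather than running a kernel with misaligned dV reads; the dedicated (256, 256) kernel is exempt because it uses its own tmem layout.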
LGTM
Description

Need FA4 version 4.0.0b11.

Type of change

Changes

Please list the changes introduced in this PR:

Checklist:
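The FA4 version requirement from the description could be enforced with a runtime guard along these lines. A sketch only: the helper `fa4_version_ok` is hypothetical, and the distribution name "flash-attn-4" is taken from the skip reason in the tests above.

```python
import importlib.metadata

# Minimum version from the PR description; pre-release tag "b11" included.
MIN_FA4 = "4.0.0b11"


def fa4_version_ok(required: str = MIN_FA4) -> bool:
    """Return True only if flash-attn-4 is installed at >= the required version."""
    try:
        installed = importlib.metadata.version("flash-attn-4")
    except importlib.metadata.PackageNotFoundError:
        return False
    # packaging's Version orders pre-releases like "4.0.0b11" correctly,
    # which naive string or tuple comparison would not.
    from packaging.version import Version
    return Version(installed) >= Version(required)
```

A guard like this would let the backend selector skip FA4 cleanly on older installs instead of failing later inside the kernel.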