[PyTorch/Common] Remove legacy FP8DS implementation #2959

Merged
cyanguwa merged 11 commits into NVIDIA:main from cyanguwa:remove_fp8_v0
May 7, 2026

Conversation

@cyanguwa
Collaborator

@cyanguwa cyanguwa commented May 5, 2026

Description

This PR removes the legacy FP8 Delayed Scaling (FP8DS) attention path that dates back to TE 1.6.0. That path supported only the T3HD layout with max_seq_len <= 512, head_dim = 64, and padding mask. cudnn-frontend will remove its pre-FORT hand-written FMHA kernels (MR2829), hence the removal of this FP8 implementation here. General THD support for FP8 will be added in future PRs.
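To make the dispatch change concrete, here is a minimal sketch of the post-PR FP8 routing logic described above: every FP8 call goes to the single remaining cuDNN frontend v1.0+ backend, and the T3HD-specific path no longer exists. The function and constant names below are illustrative, not TE's actual API.

```python
# Hypothetical sketch of FP8 fused-attention dispatch after this PR.
# The legacy T3HD path (cuDNN 8.9, max_seq_len <= 512, head_dim = 64,
# padding mask only) has been deleted, so all supported layouts route
# to the cuDNN frontend v1.0+ implementation.

SUPPORTED_FP8_LAYOUTS = {"bshd", "sbhd", "bhsd"}

def select_fp8_backend(qkv_format: str) -> str:
    """Route an FP8 attention call to the sole remaining backend."""
    if qkv_format in SUPPORTED_FP8_LAYOUTS:
        # Formerly named fused_attn_fp8_fwd_impl_v1; the _v1 suffix
        # was dropped once it became the only implementation.
        return "fused_attn_fp8_fwd_impl"
    # t3hd no longer has a dedicated FP8 path; general THD support
    # is planned for future PRs per the description.
    raise ValueError(f"NVTE_ERROR: no FP8 backend for qkv_format={qkv_format!r}")
```

This mirrors the error behavior shown in the flowchart later in the thread: unsupported formats hit NVTE_ERROR rather than falling back to a legacy kernel.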

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

See Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

cyanguwa and others added 3 commits April 30, 2026 16:50
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa marked this pull request as ready for review May 5, 2026 19:31
@greptile-apps
Contributor

greptile-apps Bot commented May 5, 2026

Greptile Summary

This PR removes the legacy cuDNN v0 FP8 Delayed Scaling attention implementation that supported only T3HD layout with max_seq_len≤512, head_dim=64, and padding mask. The canonical cuDNN frontend v1.0+ implementation (fused_attn_fp8_fwd/bwd_impl, formerly _v1) is retained and becomes the sole FP8 path.

  • ~1,630 lines of v0 CUDA graph code deleted from fused_attn_fp8.cu (hand-written createAmax, createScale, softmax/dropout helpers, QK/SV BMM nodes, cu_seqlens_to_offsets kernel) and the corresponding T3HD dispatch branches removed from fused_attn.cpp.
  • ZInv tensor eliminated end-to-end: removed from C++ function signatures (fused_attn_fp8_bwd), aux-tensor allocation in attention.cpp, aux-tensor unpacking in context_parallel.py, and all public API docs.
  • Tests simplified: cuDNN ≥9.2.1 is the sole version gate; all 8 model_configs_fp8 entries now run on the bs3hd path; t3hd-layout test branches removed.
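The ZInv removal described above can be sketched in a few lines: the FP8 aux-tensor pack shrinks from three entries to two, so unpacking becomes uniform across the forward/backward and context-parallel helpers. This is an illustrative sketch, not TE's actual code; the function name is hypothetical.

```python
# Illustrative sketch of the aux-tensor layout change this PR makes.
# Before: FP8 aux_ctx_tensors could be [S, ZInv, rng_state] on the
# T3HD path. After: the layout is uniformly [S, rng_state].

def unpack_fp8_aux(aux_ctx_tensors):
    """Unpack the post-PR two-element FP8 aux layout [S, rng_state]."""
    # A length check catches any caller still passing the old 3-element
    # T3HD pack, which no longer exists anywhere in the stack.
    if len(aux_ctx_tensors) != 2:
        raise ValueError("expected [softmax_lse, rng_state]; ZInv was removed")
    softmax_lse, rng_state = aux_ctx_tensors
    return softmax_lse, rng_state
```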

Confidence Score: 5/5

Safe to merge — the removal is complete and internally consistent across all layers (CUDA kernel, C++ dispatch, Python bindings, context-parallel helpers, and tests).

All ZInv references have been eliminated from every layer of the stack (verified via grep). The aux-tensor index arithmetic in fwd and bwd is consistent with the new two-element layout [S, rng_state]. The T3HD dispatch branch is cleanly removed from the backend selector and both fwd/bwd call sites. The renamed functions (dropping the _v1 suffix) match their new declarations and callers. No orphaned declarations or dangling includes remain.

No files require special attention. The one stale docstring in test_attention.py is cosmetic only.

Important Files Changed

Filename Overview
tests/pytorch/attention/test_attention.py Removes v0-specific parametrization (t3hd layout, cuDNN >=8.9.3 skip guard, models_v0/v1 variables); all 8 fp8 models now tested with bs3hd and cuDNN >=9.2.1. Docstring still references removed v0.9 path.
transformer_engine/common/fused_attn/fused_attn.cpp Removes the cuDNN 8.9 T3HD FP8 sub-condition from backend selection and the corresponding dispatch branches in fwd/bwd; removes ZInv from aux-tensor indexing in bwd; renames internal v1 impl calls to the now-canonical names.
transformer_engine/common/fused_attn/fused_attn_fp8.cu Removes ~1,630 lines of hand-written cuDNN v0 graph code (createAmax, createScale, softmax forward/backward, QK/SV BMM helpers, etc.) and renames fused_attn_fp8_fwd_impl_v1/bwd_impl_v1 to the canonical names; devPtrZInv removed from both signatures.
transformer_engine/common/fused_attn/fused_attn_fp8.h Header declaration for fused_attn_fp8_bwd updated to remove the input_ZInv parameter; file-level comment updated to drop the seqlen <= 512 restriction.
transformer_engine/common/fused_attn/utils.cu Removes the cu_seqlens_to_offsets CUDA kernel that was only used by the removed v0 bwd path.
transformer_engine/common/fused_attn/utils.h Removes the cu_seqlens_to_offsets kernel declaration, consistent with the utils.cu removal.
transformer_engine/common/include/transformer_engine/fused_attn.h Removes the support matrix table for the T3HD FP8 backend from public API docs and updates Aux_CTX_Tensors description to remove ZInv references.
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Removes t3hd-specific aux_ctx_tensors unpacking in cp_p2p_fwd_fused_attn, cp_p2p_bwd_fused_attn, and AttnFuncWithCPAndKVAllGather; FP8 aux tensors are now uniformly [softmax_lse, rng_state].
transformer_engine/pytorch/cpp_extensions/fused_attn.py Updates docstrings to reflect the new FP8 aux-tensor layout (S instead of M/ZInv) and removes references to T3HD and ZInv.
transformer_engine/pytorch/csrc/extensions/attention.cpp Removes the T3HD condition that previously allocated an extra ZInv tensor; FP8 aux pack now only gets S and rng_state, consistent with the backend changes.
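The test simplification in the table above reduces two version gates to one: cuDNN >= 9.2.1 is now the sole requirement for the FP8 attention tests, with the old >= 8.9.3 t3hd guard deleted. A hedged sketch of that single gate, with an illustrative version-parsing helper (not TE's actual test utility):

```python
# Sketch of the simplified version gate described above: one cuDNN
# floor for all 8 model_configs_fp8 entries, all on the bs3hd path.

MIN_CUDNN_FP8 = (9, 2, 1)  # sole version gate after this PR

def fp8_tests_enabled(cudnn_version: str) -> bool:
    """Return True when the installed cuDNN meets the single FP8 gate."""
    parsed = tuple(int(part) for part in cudnn_version.split("."))
    return parsed >= MIN_CUDNN_FP8
```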

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["nvte_fused_attn_fwd / bwd"] --> B{dtype}
    B -->|FP8| C{qkv_format}
    B -->|F16/BF16| D[F16 arbitrary-seqlen backend]
    C -->|BSHD / SBHD / BHSD| E["fused_attn_fp8_fwd_impl\n(cuDNN FE v1.0+)"]
    C -->|T3HD removed| F["REMOVED: fused_attn_fp8_fwd_impl v0\n(cuDNN 8.9, seqlen<=512, d=64)"]
    C -->|other| G[NVTE_ERROR]
    E --> H["Aux tensors: S, rng_state"]
    F -.->|"was: S, ZInv, rng_state"| H

Reviews (7): Last reviewed commit: "Merge branch 'main' into remove_fp8_v0"

cyanguwa and others added 5 commits May 5, 2026 13:37
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa changed the title from "[PyTorch/Common] Remove old, unused FP8 implementation" to "[PyTorch/Common] Remove legacy FP8DS implementation" May 5, 2026
cyanguwa added 2 commits May 5, 2026 15:11
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa added the 2.16.0 label May 5, 2026
@cyanguwa
Collaborator Author

cyanguwa commented May 5, 2026

/te-ci L0

@cyanguwa cyanguwa requested a review from sudhakarsingh27 May 5, 2026 22:21
@cyanguwa cyanguwa mentioned this pull request May 7, 2026
Comment thread transformer_engine/common/fused_attn/fused_attn.cpp
Comment thread transformer_engine/common/fused_attn/fused_attn.cpp
Comment thread transformer_engine/common/include/transformer_engine/fused_attn.h
@cyanguwa cyanguwa requested a review from sudhakarsingh27 May 7, 2026 23:47
sudhakarsingh27
sudhakarsingh27 previously approved these changes May 7, 2026
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Collaborator

@sudhakarsingh27 sudhakarsingh27 left a comment


LGTM

@cyanguwa cyanguwa merged commit e8c0dc6 into NVIDIA:main May 7, 2026
10 of 12 checks passed
@cyanguwa
Copy link
Copy Markdown
Collaborator Author

cyanguwa commented May 8, 2026

/te-ci pytorch L0
