Disable the RHT fusion for non-SM100 family devices#2968

Open
ptrendx wants to merge 2 commits into NVIDIA:main from ptrendx:pr_fix_rht_fusion

@ptrendx
Member

@ptrendx ptrendx commented May 8, 2026

Description

Disable the RHT fusion on devices outside the SM100 family (the fused kernel uses too much shared memory to run on, e.g., SM120).

Fixes #2956

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add a check on the SM architecture when testing for fusion eligibility.
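
To illustrate the change, here is a minimal sketch of the eligibility predicate with the architecture passed in as a parameter. The function name `eligible_for_rht_cast_fusion_sketch` and its parameters are hypothetical, and the sketch assumes that `sm_arch()` encodes the compute capability as 10 * major + minor (so SM100 is 100 and SM120 is 120). The real guard in quantizer.cpp queries the device instead.

```cpp
#include <cstddef>

// Hypothetical stand-in for the guard in quantizer.cpp. The SM architecture is
// an explicit argument here (assumed encoding: 10 * major + minor) so that the
// predicate can be checked without a GPU.
constexpr bool eligible_for_rht_cast_fusion_sketch(bool is_bf16, std::size_t rows,
                                                   std::size_t cols, int sm_arch) {
  return is_bf16                               // input dtype must be BFloat16
         && rows % 64 == 0 && cols % 128 == 0  // shape constraints kept from the old check
         && sm_arch >= 100 && sm_arch <= 110;  // new: restrict to the SM100-SM110 range
}

// Compile-time checks of the three interesting cases.
static_assert(eligible_for_rht_cast_fusion_sketch(true, 128, 256, 100),
              "SM100 with an aligned BF16 input stays eligible");
static_assert(!eligible_for_rht_cast_fusion_sketch(true, 128, 256, 120),
              "SM120 is excluded and falls back to the non-fused path");
static_assert(!eligible_for_rht_cast_fusion_sketch(true, 100, 256, 100),
              "rows not divisible by 64 are rejected regardless of architecture");
```

Taking the architecture as an argument is only a convenience for the sketch; it keeps the shape, dtype, and architecture conditions testable in isolation.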

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
@ptrendx ptrendx requested a review from timmoon10 May 8, 2026 00:07
@greptile-apps
Contributor

greptile-apps Bot commented May 8, 2026

Greptile Summary

This PR fixes a crash/failure on SM120 (RTX 50-series, GB20x) devices by restricting the RHT cast fusion kernel to architectures in the SM100–SM110 range, where the required shared memory and MMA hardware are available.

  • Adds sm_arch() >= 100 && sm_arch() <= 110 to the eligible_for_rht_cast_fusion guard in quantizer.cpp, so the fused kernel, which oversubscribes shared memory there, is no longer dispatched on SM120 (consumer Blackwell) or on any other architecture outside that range.
  • All other quantization paths remain unaffected; only the fused RHT+cast path is gated.

Confidence Score: 4/5

Safe to merge; the change is a conservative guard that falls back to the non-fused path on excluded architectures, so correctness is maintained even if the boundary is slightly off.

The fix is minimal and well-targeted. The only open question is whether <= 110 or < 120 is the right upper bound — if an SM111/SM112 Blackwell compute variant exists, the fusion would be skipped there, costing performance but not correctness.

transformer_engine/pytorch/csrc/quantizer.cpp — specifically the upper-bound value in the arch range check.

Important Files Changed

Filename | Overview
transformer_engine/pytorch/csrc/quantizer.cpp | Adds SM-architecture range check (>= 100 && <= 110) to gate RHT cast fusion eligibility; the upper bound of 110 may be slightly narrower than intended.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[quantize_impl called] --> B{input dtype == BFloat16\nrows % 64 == 0\ncols % 128 == 0}
    B -- No --> D[eligible_for_rht_cast_fusion = false]
    B -- Yes --> C{sm_arch >= 100\nAND sm_arch <= 110}
    C -- Yes\nSM100-SM110\nBlackwell compute --> E[eligible_for_rht_cast_fusion = true]
    C -- No\nSM120 RTX 50xx\nor other arch --> D
    E --> F[Use fused RHT + cast kernel]
    D --> G[Use separate RHT and cast paths]

Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."

  bool eligible_for_rht_cast_fusion =
-     input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
+     input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 &&
+     transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;

P2 The upper bound <= 110 is tighter than the stated intent ("non-SM100 family"). Using < 120 more precisely captures "anything below SM120" and avoids silently disabling the fusion for hypothetical SM111/SM112 variants that belong to the same Blackwell compute family. The codebase already uses 120 as the implicit dividing line (SM120 = GB20x, which is the architecture that triggered the bug), so < 120 reads as clearly intentional.

Suggested change
- transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110;
+ transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() < 120;
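
To make the difference between the two bounds concrete, the following sketch enumerates the architecture values on which `<= 110` and `< 120` disagree. The helper names are hypothetical, and the architecture is again assumed to be encoded as 10 * major + minor. Whether any device actually reports a value in the disagreeing range is not established by this PR.

```cpp
#include <vector>

// The two candidate guards, with the architecture passed in explicitly
// (assumed encoding: 10 * major + minor).
constexpr bool in_range_le_110(int sm_arch) { return sm_arch >= 100 && sm_arch <= 110; }
constexpr bool in_range_lt_120(int sm_arch) { return sm_arch >= 100 && sm_arch < 120; }

// Collect every architecture value in [lo, hi] on which the two guards disagree.
inline std::vector<int> disagreeing_archs(int lo, int hi) {
  std::vector<int> out;
  for (int arch = lo; arch <= hi; ++arch) {
    if (in_range_le_110(arch) != in_range_lt_120(arch)) {
      out.push_back(arch);
    }
  }
  return out;
}

// Both bounds agree on the cases the PR is about.
static_assert(in_range_le_110(100) && in_range_lt_120(100), "SM100 passes either guard");
static_assert(!in_range_le_110(120) && !in_range_lt_120(120), "SM120 fails either guard");
```

Calling `disagreeing_archs(90, 130)` yields the values 111 through 119, so the choice of bound only matters if a device reporting one of those values exists and can run the fused kernel.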
