Disable the RHT fusion for non-SM100 family devices#2968
Disable the RHT fusion for non-SM100 family devices#2968ptrendx wants to merge 2 commits intoNVIDIA:mainfrom
Conversation
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
for more information, see https://pre-commit.ci
Greptile SummaryThis PR fixes a crash/failure on SM120 (RTX 50-series, GB20x) devices by restricting the RHT cast fusion kernel to SM100–SM110 range architectures, where the required shared memory and MMA hardware are available.
Confidence Score: 4/5Safe to merge; the change is a conservative guard that falls back to the non-fused path on excluded architectures, so correctness is maintained even if the boundary is slightly off. The fix is minimal and well-targeted. The only open question is whether <= 110 or < 120 is the right upper bound — if an SM111/SM112 Blackwell compute variant exists, the fusion would be skipped there, costing performance but not correctness. transformer_engine/pytorch/csrc/quantizer.cpp — specifically the upper-bound value in the arch range check. Important Files Changed
Flowchart%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[quantize_impl called] --> B{input dtype == BFloat16\nrows % 64 == 0\ncols % 128 == 0}
B -- No --> D[eligible_for_rht_cast_fusion = false]
B -- Yes --> C{sm_arch >= 100\nAND sm_arch <= 110}
C -- Yes\nSM100-SM110\nBlackwell compute --> E[eligible_for_rht_cast_fusion = true]
C -- No\nSM120 RTX 50xx\nor other arch --> D
E --> F[Use fused RHT + cast kernel]
D --> G[Use separate RHT and cast paths]
Reviews (1): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..." | Re-trigger Greptile |
| bool eligible_for_rht_cast_fusion = | ||
| input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0; | ||
| input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0 && | ||
| transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110; |
There was a problem hiding this comment.
The upper bound
<= 110 is tighter than the stated intent ("non-SM100 family"). Using < 120 more precisely captures "anything below SM120" and avoids silently disabling the fusion for hypothetical SM111/SM112 variants that belong to the same Blackwell compute family. The codebase already uses 120 as the implicit dividing line (SM120 = GB20x, which is the architecture that triggered the bug), so < 120 reads as clearly intentional.
| transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() <= 110; | |
| transformer_engine::cuda::sm_arch() >= 100 && transformer_engine::cuda::sm_arch() < 120; |
Description
Disable the RHT fusion for non-sm100 class devices (the kernel uses too much shared memory to be runnable on e.g. sm120).
Fixes #2956
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: