
[common] Grouped gemm update - nvfp4 for blackwell and fp8 blockwise hopper #2971

Draft
pggPL wants to merge 15 commits into NVIDIA:main from pggPL:grouped_gemm_nvfp4_and_hopper

Conversation

@pggPL (Collaborator) commented May 8, 2026

Description

This PR extends grouped GEMM support in two directions: NVFP4 grouped GEMM on Blackwell, and FP8 blockwise-scaling grouped GEMM on Hopper (SM90) with cuBLAS 13.4+.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Add NVFP4 grouped GEMM support on Blackwell (cuBLAS 13.3+), including the discrete-inputA path and swizzled-scales handling
  • Enable FP8 blockwise-scaling grouped GEMM on Hopper (SM90) with cuBLAS 13.4+, using a scalar alpha/beta since per-group alpha/beta is not supported there
  • Refactor execute_grouped_gemm around a GroupedGemmConfig struct and add divisibility-by-128 validation for FP8 block scaling

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 15 commits March 16, 2026 11:36
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Use existing nvte_set_grouped_tensor_param with kNVTEGroupedWithGEMMSwizzledScales
instead of the dedicated set/get functions.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
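
For context, a toy model of the pattern adopted in this commit: one generic setter keyed by a param enum instead of a dedicated set/get pair per attribute. The struct and setter below mirror the names above but are not the confirmed nvte_set_grouped_tensor_param signature:

```cpp
#include <cassert>

enum NVTEGroupedTensorParam : int {
  kNVTEGroupedWithGEMMSwizzledScales = 0,  // flag named in this commit
};

// Simplified stand-in for the grouped tensor.
struct GroupedTensor {
  bool swizzled_scales = false;
};

// Generic setter: dispatch on the param enum rather than exposing a dedicated
// function per flag (the shape of the real API is assumed here).
void set_grouped_tensor_param(GroupedTensor* t, NVTEGroupedTensorParam param,
                              const void* value) {
  switch (param) {
    case kNVTEGroupedWithGEMMSwizzledScales:
      t->swizzled_scales = *static_cast<const bool*>(value);
      break;
  }
}

int main() {
  GroupedTensor t;
  const bool flag = true;
  set_grouped_tensor_param(&t, kNVTEGroupedWithGEMMSwizzledScales, &flag);
  assert(t.swizzled_scales);
  return 0;
}
```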
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
- Add CUBLAS_NVFP4_GROUPED_GEMM_VERSION and CUBLAS_FP8_BLOCK_GROUPED_GEMM_VERSION macros (13.4+)
- Update check_grouped_gemm_requirements to allow SM90 with cuBLAS 13.4+
- Refactor execute_grouped_gemm to use GroupedGemmConfig struct
- Add divisibility-by-128 validation for FP8 block scaling in setup kernel and quantizer
- Support scalar alpha/beta for Hopper (no per-group alpha/beta)
- Expose get_grouped_gemm_setup_workspace_size to PyTorch via pybind
- Update PyTorch tests to run grouped GEMM on Hopper with cuBLAS 13.4+

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
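
For orientation, a minimal sketch of the SM/cuBLAS gating that check_grouped_gemm_requirements performs per the bullets above. The integer version encodings and the helper's exact signature are assumptions for illustration only:

```cpp
#include <stdexcept>

// Assumed integer encodings of the cuBLAS versions named above; the real
// values come from CUBLAS_NVFP4_GROUPED_GEMM_VERSION and friends.
constexpr int kBlackwellGroupedGemmVersion = 130300;  // 13.3, assumed encoding
constexpr int kHopperGroupedGemmVersion = 130400;     // 13.4, assumed encoding

void check_grouped_gemm_requirements(int sm_arch, int cublas_version) {
  if (sm_arch >= 100) {
    // Blackwell and newer: grouped GEMM needs cuBLAS 13.3+.
    if (cublas_version < kBlackwellGroupedGemmVersion)
      throw std::runtime_error("grouped GEMM on SM100+ requires cuBLAS >= 13.3");
  } else if (sm_arch == 90) {
    // Hopper: FP8 blockwise grouped GEMM needs cuBLAS 13.4+.
    if (cublas_version < kHopperGroupedGemmVersion)
      throw std::runtime_error("grouped GEMM on SM90 requires cuBLAS >= 13.4");
  } else {
    throw std::runtime_error("grouped GEMM requires SM90 or newer");
  }
}
```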
… scaling tests on Hopper

Extend nvte_grouped_gemm_with_discrete_inputA to handle NVFP4 (Float4E2M1)
inputs: accept kFloat4E2M1 dtype, propagate scale_inv pointers, collect
contiguous amax from discrete tensors, and enforce swizzled-scales checks
for NVFP4 alongside MXFP8. Also add GTEST_SKIP for FP8 tensor scaling
grouped GEMM on Hopper since cuBLAS does not support it there.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
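
A minimal sketch of the validation flow described above: accept kFloat4E2M1 alongside MXFP8, require swizzled scales for both, and gather amax values contiguously from the discrete tensors. The DType enum and DiscreteInput struct are simplified stand-ins, not the actual NVTE types:

```cpp
#include <stdexcept>
#include <vector>

enum class DType { kFloat8E4M3, kMXFP8, kFloat4E2M1 };

// Simplified stand-in for a discrete A-list entry.
struct DiscreteInput {
  DType dtype;
  const void* scale_inv;  // per-tensor inverse-scale pointer to propagate
  float amax;             // gathered into one contiguous buffer below
  bool swizzled_scales;
};

std::vector<float> collect_amax_and_validate(const std::vector<DiscreteInput>& A_list) {
  std::vector<float> contiguous_amax;
  contiguous_amax.reserve(A_list.size());
  for (const auto& a : A_list) {
    // Swizzled-scales check enforced for NVFP4 alongside MXFP8.
    const bool needs_swizzle =
        (a.dtype == DType::kMXFP8 || a.dtype == DType::kFloat4E2M1);
    if (needs_swizzle && !a.swizzled_scales)
      throw std::runtime_error("NVFP4/MXFP8 grouped GEMM requires swizzled scales");
    if (a.scale_inv == nullptr)
      throw std::runtime_error("missing scale_inv pointer for discrete input");
    contiguous_amax.push_back(a.amax);
  }
  return contiguous_amax;
}
```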
…M tests

The setup kernel computes per-tensor scale pointers as data_offset /
block_size, which assumes no padding in the scale buffer. This is only
correct when first_dim % 128 == 0 and last_dim % 128 == 0 (MXFP8) or
last_dim % 64 == 0 (NVFP4). Add explicit assertions in
build_grouped_tensor to catch any future test shapes that violate this.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
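
To spell out why the assertion is needed, a standalone sketch of the pointer arithmetic and the resulting alignment constraint. Names and block-size handling are simplified relative to the setup kernel:

```cpp
#include <cassert>
#include <cstddef>

// The setup kernel derives each group's scale pointer purely from its data
// offset. This is valid only if no padding sits between groups in the scale
// buffer, i.e. every preceding group contributed exactly
// (element count / block_size) scale entries.
const float* scale_ptr_for_group(const float* scale_base, size_t data_offset,
                                 size_t block_size) {
  return scale_base + data_offset / block_size;
}

// Alignment guard matching the commit message: first_dim % 128 == 0, and
// last_dim % 128 == 0 for MXFP8 or last_dim % 64 == 0 for NVFP4.
void assert_grouped_gemm_alignment(size_t first_dim, size_t last_dim, bool nvfp4) {
  assert(first_dim % 128 == 0);
  assert(last_dim % (nvfp4 ? 64 : 128) == 0);
}
```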
…d_hopper

Conflicts resolved (3 files):

* tests/pytorch/test_numerics.py
  test_grouped_gemm_grouped_tensor: combined skip rules — Hopper (SM90) requires
  cuBLAS 13.4+, Blackwell+ (SM100) requires cuBLAS 13.3+. Kept main's
  use_bias_scale parametrization.

* transformer_engine/pytorch/cpp_extensions/gemm.py
  general_grouped_gemm_for_grouped_tensor: combined HEAD's num_alphabeta logic
  (single scalar on Hopper, per-group on Blackwell+) with main's cached
  _get_fp32_ones_tensor / _get_fp32_zeros_tensor helpers.

* transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
  - validate_grouped_gemm_inputs: kept HEAD's NVFP4 / FP8 block-scaling
    consistency checks, wrapped in main's nullptr-guard / continue-on-no-data
    pattern.
  - GroupedGemmConfig struct retained; added sm_count from main and
    propagated config_.sm_count -> gemm_config.sm_count in all three
    public APIs.
  - kMaxTensorsPerKernel renamed to kMaxGroups (= 64), adopted from main.
  - execute_grouped_gemm signature uses GroupedGemmConfig (HEAD); body uses
    config.sm_count for CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET (from main).
  - Dropped HEAD's simple grouped_bias_add_kernel (dead code); kept main's
    advanced grouped_bias_add_kernel + find_tensor_for_row helper.
  - Replaced inline SM/cuBLAS preambles with check_grouped_gemm_requirements()
    calls in nvte_grouped_gemm, nvte_grouped_gemm_with_discrete_inputA, and
    nvte_grouped_gemm_with_discrete_out. The helper supports both
    Hopper (SM90 + cuBLAS 13.4+) and Blackwell+ (SM100 + cuBLAS 13.3+).
  - Kept HEAD's validate_grouped_gemm_inputs(..., use_per_group_alpha_beta)
    signature for proper alpha/beta validation across architectures.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
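
A minimal sketch of the num_alphabeta rule merged above, as it might look host-side; the real logic lives in general_grouped_gemm_for_grouped_tensor and the exact signature is an assumption:

```cpp
#include <cstddef>

// Hopper accepts only a single scalar alpha/beta for the whole grouped GEMM;
// Blackwell+ supports one alpha/beta pair per group.
size_t num_alphabeta(int sm_arch, size_t num_groups) {
  return (sm_arch >= 100) ? num_groups : 1;
}
```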
…or swizzle tests

cublaslt_grouped_gemm.cu:
- Fix incorrect handling of NVFP4/MXFP8 columnwise data in
  build_grouped_gemm_multi_inputA_args by adding a swap_dims flag
  consistent with choose_grouped_operand_storage. Use A_sel.trans
  (post-flip) for gemm_config.avg_k so K is selected from the
  correct dim with discrete A_list.

tests/cpp/test_common.{h,cu}:
- Add enforce_grouped_gemm_alignment parameter (default true) to
  build_grouped_tensor; the MXFP8/NVFP4 first/last_dim 128/64
  alignment asserts are only relevant for the grouped GEMM setup
  kernel, so callers that bypass it (swizzle/unswizzle) opt out.

tests/cpp/operator/test_swizzle.cu:
- Pass enforce_grouped_gemm_alignment=false to build_grouped_tensor
  in MXFP8 swizzle/unswizzle/roundtrip tests, which intentionally
  exercise non-padded shapes.

tests/cpp/operator/test_grouped_gemm.cu:
- Sync GPU/cuBLAS skip rules across all 3 sub-tests, add
  cudaDeviceSynchronize() after nvte_multi_tensor_gemm reference for
  defensive sync, and skip NVFP4 + AllDifferent in all 3 sub-tests
  due to a known flaky bug in the nvte_multi_tensor_gemm reference.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
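
A simplified sketch of the K-selection fix described above: with columnwise NVFP4/MXFP8 data the storage dims of A are flipped, so the contraction dim for gemm_config.avg_k must be read from the post-flip transpose flag. OperandSelection is a stand-in for the real selection struct, and the row/column convention here is an assumption:

```cpp
#include <cstddef>

// Stand-in for the operand-selection result (A_sel in the commit message).
struct OperandSelection {
  bool trans;         // post-flip transpose flag
  size_t rows, cols;  // storage dims after any swap_dims flip
};

// K is the contraction dim: if A is stored transposed (K x M), K is the row
// count; otherwise (M x K) it is the column count.
size_t contraction_dim(const OperandSelection& A_sel) {
  return A_sel.trans ? A_sel.rows : A_sel.cols;
}
```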
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Made-with: Cursor
…and_hopper

# Conflicts:
#	transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
