Implement 4over6 NVFP4 recipe#2972

Draft
zianglih wants to merge 20 commits into NVIDIA:main from zianglih:4over6

Conversation

@zianglih
Contributor

@zianglih zianglih commented May 9, 2026

Description

Implement 4over6 NVFP4 from:

FlashInfer PR:

Enable per-block map-to-4 versus map-to-6 candidate selection for NVFP4 1D quantization in the NVFP4BlockScaling recipe. This mode currently requires RHT, stochastic rounding, and 2D quantization to be disabled. Both original per-tensor scaling and row-scaling NVFP4 introduced by #2931 are supported.
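The per-block candidate selection described above can be sketched in NumPy. This is a minimal illustration, not the actual kernel: the helper names are hypothetical, and the real implementation uses FP8-encoded block scales, a global scale ceiling, and PTX conversion instructions, all omitted here.

```python
import numpy as np

# Representable FP4 (E2M1) magnitudes.
FP4_VALUES = np.array([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0])

def quantize_block_fp4(block, scale):
    """Quantize-then-dequantize one block with a given decode scale:
    round each scaled element to the nearest FP4 magnitude."""
    scaled = np.abs(block) / scale
    idx = np.abs(scaled[:, None] - FP4_VALUES[None, :]).argmin(axis=1)
    return np.sign(block) * FP4_VALUES[idx] * scale

def quantize_block_4over6(block):
    """Try two scale candidates per 16-element block: map the block
    amax to FP4 value 4 (map-to-4) or to 6 (map-to-6), and keep
    whichever candidate yields the lower MSE after dequantization."""
    amax = np.abs(block).max()
    if amax == 0.0:
        return block.copy(), 1.0
    best = None
    for target in (4.0, 6.0):
        scale = amax / target
        deq = quantize_block_fp4(block, scale)
        mse = float(np.mean((block - deq) ** 2))
        if best is None or mse < best[0]:
            best = (mse, deq, scale)
    return best[1], best[2]
```

Intuitively, map-to-6 gives finer resolution for small values, while map-to-4 gives finer spacing near the block maximum (4's grid neighbors are closer than 6's); the MSE comparison decides per block.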

This PR also fixes a few minor bugs for row-scaled NVFP4 from #2931.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Ziang Li <ziangli@umich.edu>
@zianglih zianglih marked this pull request as draft May 9, 2026 03:50
@zianglih zianglih changed the title Implement 4over6 nvfp4 Implement 4over6 nvfp4 recipe May 9, 2026
@zianglih zianglih changed the title Implement 4over6 nvfp4 recipe Implement 4over6 NVFP4 recipe May 9, 2026
@greptile-apps
Contributor

greptile-apps Bot commented May 9, 2026

Greptile Summary

This PR implements 4over6 block-scale selection for NVFP4 1D quantization: for each 16-element block the kernel computes two scale candidates (map-to-4 and map-to-6), quantizes with both, and picks the one with lower MSE, using a reduced global-scale ceiling of 256 instead of 448 to give the map-4 branch room to represent larger blocks.

  • Adds USE_4OVER6 / kUse4Over6 template parameters to both CUDA quantization kernels (quantize_transpose_nvfp4_tuned_1D.cuh and quantize_transpose_vector_blockwise_fp4.cu) and doubles the compile-time dispatch switch depth.
  • Threads use_4over6 through the full Python → C++ → CUDA stack (recipe, quantizer, tensor storage, C++ extensions) with guards disabling incompatible modes (RHT, stochastic rounding, 2D quantization, grouped quantization).
  • Ships a matching reference Python implementation in NVFP4QuantizerRef and adds parametrised test coverage for the new mode.

Confidence Score: 4/5

The 4over6 logic is well-guarded with incompatible-mode checks at every entry point; the most notable gap is that NVTE_USE_FAST_MATH is wired only into the split-quantization path and has no effect on the standard single-tensor path.

The core quantization math and scale-selection logic look correct and are backed by a reference implementation with tests. The main concerns are the NVTE_USE_FAST_MATH env var being silently ignored for single-tensor 4over6 quantization, and the hardcoded [2] array sizes in the CUDA kernel that rely on an implicit loop-bound assumption without a static assert.

The affected files are transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu (hardcoded array sizing) and transformer_engine/pytorch/csrc/extensions/cast.cpp (fast-math env var only applied in the split path).

Important Files Changed

  • transformer_engine/common/cast/nvfp4/core_nvfp4.cuh: Adds the USE_4OVER6 template param to global scale computation and introduces compute_4over6_decoding_scaling_factors + cvt_fp32_to_fp4_8x_with_mse_rn (Blackwell-only PTX). Core 4over6 math looks correct; includes both fast-math and precise-rounding variants.
  • transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh: Extends the rowwise_scaling and colwise_scaling device functions with a USE_4OVER6 template parameter that computes both map4/map6 candidates and picks the lower-MSE result; the kernel template parameter list gains USE_4OVER6 and doubles the switch nesting for dispatch.
  • transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu: Adds kUse4Over6 and kUseFastMath template params and inlines the MSE-based candidate selection for both rowwise and transpose paths; the hardcoded output_vec_map4[2] / output_vec_map6[2] arrays assume the loop bound equals 8 without a static assert.
  • transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py: The reference quantizer gains a use_4over6 flag and a full MSE-based block-scale selection path in _quantize_tiles; 4over6 returns early, correctly bypassing the standard encode_scale path. Validation guards against incompatible combos (pow2, 2D, RHT).
  • transformer_engine/pytorch/csrc/extensions/cast.cpp: Adds use_4over6 guards for grouped/split quantization paths and propagates the flag to QuantizationConfig; the NVTE_USE_FAST_MATH env-var read is added only in the split path, leaving single-tensor quantization unable to enable fast math.
  • transformer_engine/pytorch/csrc/quantizer.cpp: Reads use_4over6 from the Python quantizer, sets it on QuantizationConfig and all output tensor wrappers, and validates incompatible flags (RHT, 2D, stochastic rounding) at quantization time. use_fast_math is never set here, consistent with the split-path-only gap.
  • transformer_engine/common/recipe/init.py: Adds the enable_4over6 recipe field and env-var binding; __post_init__ asserts that RHT, stochastic rounding, and 2D quantization are all disabled, but the assertion messages don't hint which env vars to set to satisfy them.
  • transformer_engine/pytorch/tensor/nvfp4_tensor.py: Adds a use_4over6 attribute to NVFP4Quantizer and propagates it through copy(), make_empty(), and NVFP4Tensor.__new__. Straightforward, no issues.
  • transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py: Adds a _use_4over6 storage attribute and propagates it through __new__, copy_from_storage, _get_new_kwargs, and like(). The mode-mismatch check mirrors the existing _row_scaled_nvfp4 pattern.
  • transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py: Refactors row_scaled_nvfp4 from a plain attribute to a property backed by _row_scaled_nvfp4 and adds a parallel _use_4over6 / use_4over6 property. GroupedTensor.like_copy is updated to use the private names.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["NVFP4BlockScaling(enable_4over6=True)"] --> B[NVFP4BlockScalingRecipeState]
    B --> C["NVFP4Quantizer(use_4over6=True)"]
    C --> D{quantize path}
    D -->|single tensor| E["quantize_impl (quantizer.cpp)"]
    D -->|split tensor| F["split_quantize_nvfp4_impl_helper (cast.cpp)"]
    E --> G["QuantizationConfig.set_nvfp4_4over6(true)"]
    F --> G
    F -->|reads env var| H["NVTE_USE_FAST_MATH → config.set_use_fast_math"]
    G --> I{kernel dispatch}
    I -->|tuned 1D| J["quantize_transpose_nvfp4_tuned_1D_kernel<USE_4OVER6=true>"]
    I -->|vector blockwise| K["block_scaled_1d_cast_transpose_kernel<kUse4Over6=true>"]
    J --> L["rowwise_scaling: compute map4+map6 scales, pick lower MSE"]
    K --> M["cvt_fp32_to_fp4_8x_with_mse_rn: err_map4 vs err_map6"]
    L --> N["NVFP4Tensor with _use_4over6=True"]
    M --> N

Reviews (1): Last reviewed commit: "Initial implementation"

Comment on lines 1315 to +1323
need_separate_rng_states, quant_config_list,
dummy_quant_config_list_colwise); // colwise rng states are not needed in this case

for (auto &config : quant_config_list) {
config.set_nvfp4_4over6(quantizer.use_4over6);
}

const auto use_fast_math = transformer_engine::getenv<bool>("NVTE_USE_FAST_MATH");
if (use_fast_math) {

P2 NVTE_USE_FAST_MATH only applied in split-quantization path

NVTE_USE_FAST_MATH is read and forwarded to quant_config_list only here in split_quantize_nvfp4_impl_helper. However, in quantizer.cpp::quantize_impl, no equivalent env-var read exists, so the use_fast_math field on the single-tensor QuantizationConfig stays false even when NVTE_USE_FAST_MATH=1. This means the fast-math variant of cvt_fp32_to_fp4_8x_with_mse_rn (USE_FAST_MATH=true) is unreachable for ordinary single-tensor 4over6 quantization — the env var is silently ignored on that path.

Comment on lines +730 to +752
const float x[8] = {
static_cast<float>(smem_vec[2 * (i + 0)].data.elt[smem_idx]),
static_cast<float>(smem_vec[2 * (i + 0) + 1].data.elt[smem_idx]),
static_cast<float>(smem_vec[2 * (i + 1)].data.elt[smem_idx]),
static_cast<float>(smem_vec[2 * (i + 1) + 1].data.elt[smem_idx]),
static_cast<float>(smem_vec[2 * (i + 2)].data.elt[smem_idx]),
static_cast<float>(smem_vec[2 * (i + 2) + 1].data.elt[smem_idx]),
static_cast<float>(smem_vec[2 * (i + 3)].data.elt[smem_idx]),
static_cast<float>(smem_vec[2 * (i + 3) + 1].data.elt[smem_idx]),
};
output_vec_map4[out_idx] =
transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn<
kUseFastMath>(x, encode_scale_map4, scale_inv_map4, global_amax[0], &err_map4);
output_vec_map6[out_idx] =
transformer_engine::dispatch::nvfp4::core::cvt_fp32_to_fp4_8x_with_mse_rn<
kUseFastMath>(x, encode_scale_map6, scale_inv_map6, global_amax[0], &err_map6);
}

if (err_map4 < err_map6) {
scale_inv = scale_inv_map4;
*reinterpret_cast<uint32_t*>(&output_vec.data.elt[0]) = output_vec_map4[0];
*reinterpret_cast<uint32_t*>(&output_vec.data.elt[4]) = output_vec_map4[1];
} else {

P2 Hardcoded [2] array sizes tied to implicit loop-bound assumption

output_vec_map4[2] and output_vec_map6[2] are sized for exactly two uint32_t outputs, which is correct only when kNVecOut / kNVecSMem == 8 (so the loop runs i = 0, 4). The same pattern appears for the transpose path (kNVecOut / kNFP4PerContainer). There is no static_assert verifying that either ratio equals 8. If a future tuning changes kNVecSMem, kNFP4PerContainer, or kNVecOut, the index out_idx = i / 4 will silently write past the end of these stack arrays, corrupting adjacent local state inside the register file.

Comment on lines 515 to +522
assert (
self.backward_override in _BACKWARD_OVERRIDES
), "NVTE_BACKWARD_OVERRIDE must be unset or one of: 'high_precision', 'dequantized'."
if self.enable_4over6:
assert self.disable_rht, "NVFP4 4over6 currently requires RHT to be disabled"
assert (
self.disable_stochastic_rounding
), "NVFP4 4over6 currently requires stochastic rounding to be disabled"

P2 enable_4over6 silently fails when set via env var while other env vars are not set

enable_4over6 is read from NVTE_NVFP4_ENABLE_4OVER6, but the __post_init__ asserts immediately require disable_rht, disable_stochastic_rounding, and disable_2d_quantization to all be True. A user who sets only NVTE_NVFP4_ENABLE_4OVER6=1 — following a natural reading of the env var docs — gets an AssertionError at recipe construction with no actionable hint. Consider surfacing the required sibling env vars in the assertion message (e.g. "Set NVTE_NVFP4_DISABLE_RHT=1").
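The suggested actionable messages could look like the following sketch. It is illustrative only: NVTE_NVFP4_ENABLE_4OVER6 and NVTE_NVFP4_DISABLE_RHT appear in this thread, but the class shape and the other two env-var names are placeholders.

```python
from dataclasses import dataclass

@dataclass
class NVFP4BlockScalingSketch:
    """Illustrative subset of the recipe's flags (not the real class)."""
    enable_4over6: bool = False
    disable_rht: bool = False
    disable_stochastic_rounding: bool = False
    disable_2d_quantization: bool = False

    def __post_init__(self):
        if not self.enable_4over6:
            return
        # Each message names the sibling env var the user must also set,
        # so enabling only NVTE_NVFP4_ENABLE_4OVER6 fails with an
        # actionable hint instead of a bare AssertionError.
        assert self.disable_rht, (
            "NVFP4 4over6 requires RHT to be disabled "
            "(set NVTE_NVFP4_DISABLE_RHT=1)"
        )
        assert self.disable_stochastic_rounding, (
            "NVFP4 4over6 requires stochastic rounding to be disabled "
            "(set NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING=1)"  # placeholder env-var name
        )
        assert self.disable_2d_quantization, (
            "NVFP4 4over6 requires 2D quantization to be disabled "
            "(set NVTE_NVFP4_DISABLE_2D_QUANTIZATION=1)"  # placeholder env-var name
        )
```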

zianglih added 10 commits May 8, 2026 21:08
Signed-off-by: Ziang Li <ziangli@umich.edu>
if use_bias:
check_grouped_bias(te_grouped_linear, num_gemms, ffn_hidden_size)
if use_bias and te_grouped_linear.single_grouped_bias:
check_grouped_bias(te_grouped_linear, num_gemms, ffn_hidden_size)

We need this to get the test to pass, but it seems unrelated to our changes.

zianglih added 9 commits May 9, 2026 18:40
Signed-off-by: Ziang Li <ziangli@umich.edu>
