[All] Refactor nvte_get_fused_attn_backend with cudnn-frontend calls #2964
cyanguwa wants to merge 23 commits into NVIDIA:main from
Conversation
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Greptile Summary
This PR replaces the hand-maintained backend selection logic in nvte_get_fused_attn_backend with cudnn-frontend's production-grade support checks.
Confidence Score: 4/5. Safe to merge for PyTorch and JAX callers that use the new or updated bindings; external C callers of the deprecated nvte_get_fused_attn_backend will silently regress to NVTE_No_Backend because the wrapper leaves o_format unset. In transformer_engine/common/fused_attn/fused_attn.cpp, the deprecated nvte_get_fused_attn_backend wrapper (lines 347-378) needs cfg.o_format and cfg.do_format populated before delegating to nvte_get_fused_attn_backend_v2.
Sequence Diagram
sequenceDiagram
participant Caller as Caller (Python/C)
participant v2 as nvte_get_fused_attn_backend_v2
participant probe_f16 as is_supported_f16_fwd/bwd
participant probe_fp8 as is_supported_fp8_fwd/bwd
participant impl as fused_attn_*_impl (null ptrs)
participant cache as cuDNN Graph Cache
Caller->>v2: "NVTEFusedAttnConfig + &message"
v2->>v2: early checks (64-bit ragged offset, THD mask, dtype)
alt FP16/BF16
v2->>probe_f16: is_supported_f16_fwd(cfg, handle)
probe_f16->>impl: "call with devPtr=nullptr"
impl->>cache: build + cache cuDNN graph
cache-->>impl: ok / exception
impl-->>probe_f16: workspace_size or throw
probe_f16-->>v2: empty string (success) or error string
opt is_training
v2->>probe_f16: is_supported_f16_bwd(cfg, handle)
probe_f16->>impl: "call with devPtr=nullptr"
impl->>cache: build + cache cuDNN bwd graph
probe_f16-->>v2: empty string or error string
end
v2-->>Caller: NVTE_F16_arbitrary_seqlen or NVTE_No_Backend
else FP8
v2->>probe_fp8: is_supported_fp8_fwd(cfg, handle)
probe_fp8->>impl: "call with devPtr=nullptr"
impl->>cache: build + cache FP8 fwd graph
probe_fp8-->>v2: empty string or error string
opt is_training
v2->>probe_fp8: is_supported_fp8_bwd(cfg, handle)
probe_fp8-->>v2: empty string or error string
end
v2-->>Caller: NVTE_FP8 or NVTE_No_Backend
end
Reviews (8): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
| const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout); | ||
| if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD && | ||
| qkv_format != NVTE_QKV_Format::NVTE_BHSD) { | ||
| return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + "."; |
Compilation error: invalid pointer arithmetic in string construction
NVTE_QKV_Format is an unscoped C enum, so "..." + qkv_format performs pointer arithmetic (advancing the const char* literal by the integer value of the enum). The result is a const char*, and then const char* + "." tries to add two pointers, which is ill-formed in C++. This expression will not compile. The same bug exists on line 1380 (is_supported_fp8_bwd). The fix is to use std::to_string(static_cast<int>(qkv_format)) or construct a std::string explicitly.
| return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + "."; | |
| return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + | |
| std::to_string(static_cast<int>(qkv_format)) + "."; | |
| } | |
| size_t workspace_size = 0; |
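For illustration, a minimal standalone snippet (using a hypothetical unscoped enum standing in for NVTE_QKV_Format) showing why the original expression is ill-formed and why the std::to_string form compiles:

```cpp
#include <string>

enum Format { BSHD = 0, SBHD = 1, BHSD = 2 };  // stand-in for the unscoped NVTE_QKV_Format enum

std::string describe(Format f) {
  // Ill-formed: "literal" + f performs pointer arithmetic on the const char* literal,
  // and the trailing + "." then tries to add two pointers, which does not compile:
  // return "unsupported format, found " + f + ".";

  // Well-formed: convert the enum to text first, so operator+ concatenates std::strings.
  return "unsupported format, found " + std::to_string(static_cast<int>(f)) + ".";
}
```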
| return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + "."; | ||
| } | ||
| const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); |
Same pointer-arithmetic compilation error as in is_supported_fp8_fwd — adding qkv_format (an integer-valued enum) to a string literal produces a const char*, and then adding "." yields an ill-formed pointer+pointer expression. Use std::to_string to produce a proper string.
| return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + "."; | |
| } | |
| const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); | |
| return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + | |
| std::to_string(static_cast<int>(qkv_format)) + "."; | |
| } | |
| const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype); |
| constexpr size_t probe_batch = 1; | ||
| constexpr bool probe_bottom_right_diagonal = false; |
Hardcoded probe parameters may silently misclassify configurations
probe_batch = 1 and probe_bottom_right_diagonal = false are used for all cudnn-frontend probe calls regardless of the actual values passed by the caller. When attn_mask_type is NVTE_CAUSAL_BOTTOM_RIGHT_MASK or NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, the probe uses false for bottom_right_diagonal even though the real call will use true, which can cause the probe to greenlight configurations that fail at actual execution time.
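A possible direction, sketched below under the assumption that the probe signature and NVTEFusedAttnConfig field names match those shown elsewhere in this PR, is to derive the probe parameters from the incoming config instead of compile-time constants:

```cpp
// Sketch (assumed signatures/field names): let the probe honour the caller's mask and
// bottom_right_diagonal settings rather than probing with a fixed `false`.
std::string is_supported_f16_fwd(const NVTEFusedAttnConfig &cfg, cudnnHandle_t handle) {
  const bool probe_bottom_right_diagonal =
      cfg.bottom_right_diagonal ||
      cfg.attn_mask_type == NVTE_Mask_Type::NVTE_CAUSAL_BOTTOM_RIGHT_MASK ||
      cfg.attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK;
  const int64_t probe_batch = cfg.batch_size;  // instead of a hardcoded 1
  // ... build and cache the cuDNN graph with probe_batch / probe_bottom_right_diagonal ...
  return "";  // empty string signals "supported" in this PR's convention
}
```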
| thread_local std::string fused_attn_backend_message_buffer; | ||
| | ||
| void set_message(const char **message, const std::string &reason) { | ||
| if (message == nullptr) return; | ||
| fused_attn_backend_message_buffer = reason; | ||
| *message = fused_attn_backend_message_buffer.c_str(); | ||
| } |
Thread-local buffer invalidated on next call on the same thread
*message is set to the .c_str() of the thread-local fused_attn_backend_message_buffer. Any subsequent call to nvte_get_fused_attn_backend on the same thread will clear and overwrite this buffer, invalidating the raw pointer before the caller has a chance to copy it. The internal probe calls currently pass nullptr for message, so the buffer won't be clobbered by them, but this invariant is fragile and easy to break in a future change.
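On the caller side, the only safe pattern under this design is to copy the message into owned storage before issuing any further backend queries on the same thread; a small sketch, assuming the v2 entry point shown in this PR:

```cpp
#include <string>

// Sketch: take an owning copy of the thread-local message right after the query,
// before anything else on this thread can overwrite the buffer behind the pointer.
NVTE_Fused_Attn_Backend query_backend(NVTEFusedAttnConfig *cfg, std::string *reason_out) {
  const char *message = nullptr;
  NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend_v2(cfg, &message);
  if (reason_out != nullptr) {
    *reason_out = (message != nullptr) ? std::string(message) : std::string();
  }
  return backend;
}
```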
| bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle, | ||
| const char **message); |
Breaking public API change: cudnnHandle_t added as a required parameter
nvte_get_fused_attn_backend is a public C API (no namespace, exported symbol). Adding cudnnHandle_t handle as a required parameter, and including <cudnn.h> in the public header, is a breaking change that forces every downstream consumer to hold and pass a cuDNN handle for what was previously a pure-metadata query.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
| // avoid CUDA graph issue with cuDNN <= 9.15 | ||
| if (cudnn_runtime_version <= 91500 && is_training && | ||
| (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && | ||
| (max_seqlen_kv % 128 != 0) && cuda_graph && | ||
| attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK && | ||
| attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK && | ||
| attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) { | ||
| set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN."); | ||
| return NVTE_Fused_Attn_Backend::NVTE_No_Backend; | ||
| } |
CUDA graph guard now incorrectly rejects FP8 configurations
The CUDA graph check at lines 281–290 was previously scoped exclusively to the F16/BF16 branch; moving it to the top-level pre-filter means it now also rejects FP8 + CUDA graph configurations on cuDNN ≤ 9.15.0 that were previously accepted. Users running FP8 training with CUDA graph capture, BSHD/SBHD layout, non-padding masks, and max_seqlen_kv % 128 != 0 on cuDNN ≤ 9.15.0 will see the backend silently downgrade to NVTE_No_Backend where it used to return NVTE_FP8. If the cuDNN bug does not affect FP8, this is a regression; if it does, the guard should be narrowed with a comment explaining why it applies to FP8.
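If the cuDNN bug is indeed limited to the F16/BF16 kernels, the guard could be narrowed with an explicit dtype condition; a sketch, assuming the dtype variable and NVTEDType enumerator names (the rest mirrors the code above):

```cpp
// Sketch: apply the cuDNN <= 9.15 CUDA-graph workaround only on the F16/BF16 path so
// FP8 configurations keep the behaviour they had before the guard was hoisted.
const bool is_f16_path =
    (q_dtype == NVTEDType::kNVTEFloat16 || q_dtype == NVTEDType::kNVTEBFloat16);
if (is_f16_path && cudnn_runtime_version <= 91500 && is_training && cuda_graph &&
    (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
    (max_seqlen_kv % 128 != 0) &&
    attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK &&
    attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
    attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) {
  set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN.");
  return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
```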
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
/te-ci L1
| const NVTE_QKV_Format do_format = o_format; | ||
| const NVTE_QKV_Layout dqkv_layout = qkv_layout; |
Hardcoded dqkv_layout and do_format in backward probe may build incorrect cached graph
is_supported_f16_bwd hardcodes dqkv_layout = qkv_layout and do_format = q_format. nvte_fused_attn_bwd accepts independent dqkv_layout and do_format parameters and passes them through to the actual backward kernel — but these are never forwarded into nvte_get_fused_attn_backend. When the activation and gradient layouts differ, the probe builds and caches a cuDNN graph for a configuration that won't be used. More critically, if the config with qkv_layout is unsupported but the config with the true dqkv_layout would be supported, backend selection will falsely return NVTE_No_Backend. The same assumption exists in is_supported_fp8_bwd at line 1385 of fused_attn_fp8.cu.
| const int64_t bias_sq = has_bias ? sq : 0; | ||
| const int64_t bias_skv = has_bias ? skv : 0; | ||
| | ||
| const NVTE_QKV_Format o_format = q_format; |
o_format = q_format assumption can mismatch the actual forward call
is_supported_f16_fwd derives o_format from q_format, but nvte_fused_attn_fwd accepts a separate o_format parameter that is not forwarded into nvte_get_fused_attn_backend. When the caller uses a different output format from the query format — such as returning BSHD output from an SBHD_BSHD_BSHD layout — the probe builds a cuDNN graph for the wrong o_format. If the graph with o_format=q_format is accepted but the config with the actual o_format is not (or vice versa), nvte_get_fused_attn_backend produces an incorrect backend decision, causing an error when the actual kernel is invoked.
| | ||
| // For ragged offsets we only support 32-bit prior to cuDNN 9.5 | ||
| // Only used when THD format is requested. | ||
| cudnnHandle_t handle = cudnnExecutionPlanManager::Instance().GetHandle(); |
Backend query now builds cuDNN execution graphs as a side effect, altering call semantics
cudnnExecutionPlanManager::Instance().GetHandle() is obtained here, and the subsequent probe calls invoke fused_attn_*_impl with null device pointers, building and caching cuDNN execution plans. This turns a previously cheap metadata query into a graph-build operation that can take hundreds of milliseconds on first call. Code paths that call nvte_get_fused_attn_backend as an availability check — notably FusedAttnHelper.is_fused_attn_kernel_available() from Python/JAX — will pay this cost unexpectedly during model initialisation. The new semantic should be documented in the public API header.
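One way to document the new semantics, sketched as a header comment (the wording and the v2 signature shown here are assumptions, not text from this PR):

```cpp
/*! \brief Select a fused-attention backend for the given configuration.
 *
 *  Note on cost: support is probed by building and caching the corresponding cuDNN
 *  execution graph on an internally obtained handle. The first call for a given
 *  configuration can take hundreds of milliseconds; later calls with the same
 *  configuration hit the cudnn-frontend graph cache and are cheap.
 */
NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend_v2(const NVTEFusedAttnConfig *config,
                                                       const char **message);
```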
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
| const int64_t bias_sq = has_bias ? sq : 0; | ||
| const int64_t bias_skv = has_bias ? skv : 0; | ||
| | ||
| const NVTE_QKV_Format o_format = q_format; |
o_format probe hardcoded as q_format, mismatching callers that use a different output format
is_supported_f16_fwd derives o_format = q_format (line 1372) and uses it to build and cache the cuDNN graph. However, nvte_fused_attn_fwd accepts an independent o_format parameter that is never forwarded into nvte_get_fused_attn_backend. When a caller uses an output format different from the query format (e.g. BSHD output from an SBHD_BSHD_BSHD layout), the cached graph was built for q_format, not the actual o_format. If cuDNN accepts the wrong graph but rejects the real one — or vice versa — the backend check produces an incorrect decision, causing an unexpected error at actual kernel invocation.
| const NVTE_QKV_Format o_format = q_format; | ||
| const NVTE_QKV_Format do_format = o_format; | ||
| const NVTE_QKV_Layout dqkv_layout = qkv_layout; |
Backward probe hardcodes dqkv_layout = qkv_layout and do_format = o_format, diverging from the actual backward call
is_supported_f16_bwd sets dqkv_layout = qkv_layout and do_format = o_format (where o_format is already fixed to q_format). The actual backward call nvte_fused_attn_bwd accepts independent dqkv_layout and do_format parameters that are never threaded through nvte_get_fused_attn_backend. When activation and gradient layouts differ, the probe builds a cuDNN graph for a configuration that is never used at runtime. More critically, if the real dqkv_layout is unsupported but the probe's assumed qkv_layout is accepted, the backend check returns NVTE_F16_arbitrary_seqlen and the actual backward pass silently fails. The same issue exists in is_supported_fp8_bwd at line 1385 of fused_attn_fp8.cu.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
| NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( | ||
| is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, | ||
| h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, | ||
| return_max_logit, cuda_graph, false); | ||
| is_training, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type, | ||
| softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, | ||
| window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph, | ||
| /*deterministic=*/false, /*message=*/nullptr); |
Probe graph builds triggered on every forward and backward execution step
nvte_fused_attn_fwd and nvte_fused_attn_bwd both now call nvte_get_fused_attn_backend internally, which in turn calls the probe functions (is_supported_f16_fwd, is_supported_f16_bwd, etc.) with null device pointers. The cudnn-frontend cache is keyed on the full graph specification including batch_size. When the runtime batch size b (extracted from the actual tensor shapes) differs from the value used in the user-facing availability check — for example the last mini-batch in an epoch or an inference request with a different batch size — none of the graphs built during the pre-check will match, and a full graph build is triggered mid-execution. These builds can take hundreds of milliseconds and will stall the training step unexpectedly. Designs to mitigate this include: caching a "any batch size" graph variant, or making nvte_fused_attn_fwd/nvte_fused_attn_bwd route to the backend without re-invoking the probes.
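One mitigation along the lines suggested above, sketched under the assumption that the config struct is cheap to serialize and that nvte_fused_attn_fwd/nvte_fused_attn_bwd could call a helper like this instead of re-invoking the probes directly:

```cpp
#include <mutex>
#include <string>
#include <unordered_map>

// Sketch: memoise backend decisions so the execution path never rebuilds probe graphs
// once a configuration has been classified (the key below is abbreviated for brevity;
// a real key must cover every field that influences backend selection).
NVTE_Fused_Attn_Backend get_backend_cached(NVTEFusedAttnConfig cfg) {
  static std::mutex mu;
  static std::unordered_map<std::string, NVTE_Fused_Attn_Backend> decisions;
  std::string key = std::to_string(static_cast<int>(cfg.qkv_layout)) + '/' +
                    std::to_string(cfg.max_seqlen_q) + '/' + std::to_string(cfg.max_seqlen_kv) +
                    '/' + std::to_string(cfg.head_dim_qk) + '/' + std::to_string(cfg.is_training);
  std::lock_guard<std::mutex> lock(mu);
  auto it = decisions.find(key);
  if (it != decisions.end()) return it->second;
  NVTE_Fused_Attn_Backend backend = nvte_get_fused_attn_backend_v2(&cfg, /*message=*/nullptr);
  decisions.emplace(std::move(key), backend);
  return backend;
}
```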
| @@ -146,6 +157,7 @@ def get_fused_attn_backend(self): | |||
| self.head_dim_v, | |||
| self.window_size[0], | |||
| self.window_size[1], | |||
| self.bottom_right_diagonal, | |||
| not self.is_non_deterministic_allowed(), | |||
| ) | |||
o_type and scaling_mode are always fixed regardless of JAX FP8 DPA config
FusedAttnHelper.get_fused_attn_backend() always passes o_type = q_type and scaling_mode = NVTEScalingMode.NVTE_INVALID_SCALING. For non-FP8 JAX attention this is correct, but the same helper is the one called from FusedAttnFwdPrimitive.abstract to determine which backend to activate. If a JAX caller configures FP8 DPA (e.g. FP8 q_dtype with BF16 output), the probe will enter the FP8 branch of nvte_get_fused_attn_backend with o_dtype=q_dtype (wrong) and scaling_mode=NVTE_INVALID_SCALING (likely unsupported by the FP8 cuDNN graphs), causing the probe to return NVTE_No_Backend and silently falling back to unfused attention even though the configuration would be supported with the correct o_type and scaling_mode. The FusedAttnHelper dataclass should expose o_dtype and scaling_mode fields analogous to what utils.py computes for PyTorch.
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
| NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend( | ||
| is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout, | ||
| h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right, | ||
| return_max_logit, cuda_graph, false); | ||
| is_training, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, o_format, | ||
| /*do_format=*/o_format, /*dqkv_layout=*/qkv_layout, qkv_scale_inv_format, | ||
| /*do_scale_inv_format=*/qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type, | ||
| attn_scale, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, | ||
| window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph, | ||
| /*deterministic=*/false, /*message=*/nullptr); |
Forward path re-probes backward with deterministic=false, wasting graph cache entries for deterministic training
nvte_fused_attn_fwd calls nvte_get_fused_attn_backend with is_training=is_training and hardcoded deterministic=false. When a user trains with deterministic=true, the forward invocation builds and caches a cuDNN graph for the non-deterministic backward. A few lines later, nvte_fused_attn_bwd calls nvte_get_fused_attn_backend again with deterministic=true, causing the backward graph to be rebuilt from scratch on every step. The deterministic flag should be threaded through to this site (as it already is in the nvte_fused_attn_bwd call), or the forward pass should skip the backward probe entirely and let nvte_fused_attn_bwd drive its own backend check.
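A sketch of the first option, assuming the caller's deterministic flag is (or can be) threaded into nvte_fused_attn_fwd; everything else mirrors the call shown in the diff above:

```cpp
// Sketch: forward the caller's deterministic flag so the backward probe graph built here
// matches the one nvte_fused_attn_bwd will later request, instead of hardcoding false.
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
    is_training, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, o_format,
    /*do_format=*/o_format, /*dqkv_layout=*/qkv_layout, qkv_scale_inv_format,
    /*do_scale_inv_format=*/qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type,
    attn_scale, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left,
    window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph,
    /*deterministic=*/deterministic, /*message=*/nullptr);
```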
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
for more information, see https://pre-commit.ci
| NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT; | ||
| cfg.qkv_layout = qkv_layout; | ||
| cfg.dqkv_layout = qkv_layout; // legacy: gradient layout matches input layout | ||
| cfg.bias_type = bias_type; | ||
| cfg.attn_mask_type = attn_mask_type; | ||
| cfg.softmax_type = softmax_type; | ||
| cfg.attn_scale = 1.0f; // legacy default; matches the value pre-PR probes hardcoded | ||
| cfg.dropout = dropout; | ||
| cfg.max_seqlen_q = max_seqlen_q; | ||
| cfg.max_seqlen_kv = max_seqlen_kv; | ||
| cfg.window_size_left = window_size_left; | ||
| cfg.window_size_right = window_size_right; | ||
| cfg.cuda_graph = cuda_graph; | ||
| NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type."); | ||
| cfg.qkv_dtype = q_dtype; | ||
| cfg.o_dtype = q_dtype; // legacy: O dtype matches Q dtype | ||
| cfg.batch_size = 1; // legacy: pre-PR probes assumed batch=1 | ||
| cfg.num_attn_heads = num_attn_heads; | ||
| cfg.num_gqa_groups = num_gqa_groups; | ||
| cfg.head_dim_qk = head_dim_qk; | ||
| cfg.head_dim_v = head_dim_v; | ||
| cfg.is_training = is_training; | ||
| cfg.return_max_logit = return_max_logit; | ||
| cfg.deterministic = deterministic; | ||
| return nvte_get_fused_attn_backend_v2(&cfg, /*message=*/nullptr); | ||
| } |
Deprecated wrapper omits cfg.o_format and cfg.do_format
nvte_get_fused_attn_backend builds an NVTEFusedAttnConfig via NVTE_FUSED_ATTN_CONFIG_INIT but never sets cfg.o_format or cfg.do_format, leaving both as NVTE_QKV_Format_NOT_SET. The probes (is_supported_f16_fwd, is_supported_f16_bwd, is_supported_fp8_fwd) then pass this sentinel directly to fused_attn_*_impl. If cuDNN rejects NVTE_QKV_Format_NOT_SET as an output format — which is the expected behaviour for an unrecognised enum value — every call through the deprecated API returns NVTE_No_Backend for configurations that were previously accepted. A minimal fix is to derive the missing fields from qkv_layout the same way GetFusedAttnBackend (JAX) does: cfg.o_format = nvte_get_q_format(qkv_layout); cfg.do_format = cfg.o_format;.
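A minimal sketch of that fix inside the deprecated wrapper, following the helper named in the comment (whether nvte_get_q_format is the right derivation for every layout is an assumption):

```cpp
// Sketch: derive the output and output-gradient formats from the QKV layout so the
// deprecated wrapper no longer forwards NVTE_QKV_Format_NOT_SET to the probes.
cfg.o_format = nvte_get_q_format(qkv_layout);  // legacy: O format follows the Q format
cfg.do_format = cfg.o_format;                  // legacy: dO format matches O format
return nvte_get_fused_attn_backend_v2(&cfg, /*message=*/nullptr);
```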
/te-ci L1
Description
This PR replaces the hand-maintained backend selection logic in nvte_get_fused_attn_backend with cudnn-frontend's production-grade support checks: backend availability is now probed by building the corresponding cuDNN graph inside the nvte_get_fused_attn_backend call.
Type of change
Changes
See Description.
Checklist: