
[All] Refactor nvte_get_fused_attn_backend with cudnn-frontend calls#2964

Open
cyanguwa wants to merge 23 commits into NVIDIA:main from cyanguwa:fe_check_support

Conversation

@cyanguwa
Collaborator

@cyanguwa cyanguwa commented May 6, 2026

Description

This PR replaces the hand-maintained backend selection logic in nvte_get_fused_attn_backend with cudnn-frontend's production-grade support checks.

  • It builds the same graph that the runtime uses for execution, caches the graph if the build succeeds, and thus provides a warmed-up cache for the next nvte_get_fused_attn_backend call.
  • It provides cleaner dispatch logic, avoids accidental regressions, and helps TE stay in sync with cudnn-frontend's support surface.
  • It also provides appropriate error messages when no backend is available, so users are notified and can adjust their config, architecture, or cuDNN version.
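
The core idea can be sketched in a few lines. This is a hedged illustration, not TE's actual API: probe_support and build_fn are hypothetical names standing in for the cuDNN-frontend graph build, which throws when a configuration is unsupported.

```cpp
#include <functional>
#include <stdexcept>
#include <string>

// build_fn stands in for the cuDNN-frontend graph build for a given config.
// An empty return string signals "supported"; otherwise the string carries
// the diagnostic that nvte_get_fused_attn_backend can surface to the user.
std::string probe_support(const std::function<void()> &build_fn) {
  try {
    build_fn();       // in the real code this also caches the built graph
    return "";        // success: no diagnostic
  } catch (const std::exception &e) {
    return e.what();  // failure: surface the reason to the caller
  }
}
```

Because the probe exercises the same build path as execution, a successful probe leaves the graph cached for the real forward/backward call.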

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

See Description.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

cyanguwa and others added 4 commits May 5, 2026 18:55
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa changed the title [Common] Refactor nvte_get_fused_attn_backend with cudnn-frontend calls [All] Refactor nvte_get_fused_attn_backend with cudnn-frontend calls May 8, 2026
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa cyanguwa marked this pull request as ready for review May 8, 2026 00:10
@greptile-apps
Contributor

greptile-apps Bot commented May 8, 2026

Greptile Summary

This PR replaces the hand-maintained backend selection logic in nvte_get_fused_attn_backend with cudnn-frontend's production-grade support checks by building and caching the same cuDNN execution graph that the runtime uses. The old function is deprecated in favour of a new versioned nvte_get_fused_attn_backend_v2 that accepts a richer NVTEFusedAttnConfig struct and returns a diagnostic message when no backend matches.

  • New C API (nvte_get_fused_attn_backend_v2): introduces NVTEFusedAttnConfig with a versioning struct_size guard, NVTE_FUSED_ATTN_CONFIG_INIT macro, and an is_supported_f16/fp8_fwd/bwd probe family that builds cuDNN graphs with null device pointers to verify support.
  • PyTorch / JAX bindings updated: get_fused_attn_backend in both frameworks now returns (backend, message), threads FP8 scaling_mode, o_type, o_format, batch_size, bottom_right_diagonal, and attn_scale through to the probe; the JAX C++ extension normalises NVTE_QKV_Format_NOT_SET defaults before forwarding to the core API.
  • Deprecated wrapper (nvte_get_fused_attn_backend): preserved for ABI compatibility but leaves cfg.o_format and cfg.do_format unset (NVTE_QKV_Format_NOT_SET), which the cuDNN probes receive verbatim and may reject — causing a regression for external C callers of the deprecated API.
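
A minimal sketch of the struct_size versioning pattern described above; the field names and the has_field_attn_scale helper are illustrative, not the real NVTEFusedAttnConfig layout.

```cpp
#include <cstddef>

// A size-tagged config: the INIT macro records sizeof at the caller's compile
// time, so the library can detect callers built against an older, smaller
// struct and avoid reading fields those callers never initialized.
struct ExampleConfig {
  std::size_t struct_size;
  int batch_size;
  float attn_scale;
};

#define EXAMPLE_CONFIG_INIT \
  ExampleConfig { sizeof(ExampleConfig), 1, 1.0f }

// Library side: check the tag before trusting fields added in newer versions.
inline bool has_field_attn_scale(const ExampleConfig &cfg) {
  return cfg.struct_size >= offsetof(ExampleConfig, attn_scale) + sizeof(float);
}
```

This lets nvte_get_fused_attn_backend_v2 evolve by appending fields without breaking older compiled callers.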

Confidence Score: 4/5

Safe to merge for PyTorch and JAX callers that use the new or updated bindings; external C callers of the deprecated nvte_get_fused_attn_backend will silently regress to NVTE_No_Backend because the wrapper leaves o_format unset.

The deprecated nvte_get_fused_attn_backend wrapper constructs an NVTEFusedAttnConfig via NVTE_FUSED_ATTN_CONFIG_INIT but never assigns cfg.o_format or cfg.do_format, so both remain as NVTE_QKV_Format_NOT_SET. The cuDNN probe receives this sentinel value as the output format and is likely to throw, causing the wrapper to return NVTE_No_Backend for configurations that previously returned NVTE_F16_arbitrary_seqlen or NVTE_FP8. All first-party Python callers (PyTorch and JAX) route through updated bindings that set every field correctly and are unaffected.

transformer_engine/common/fused_attn/fused_attn.cpp — the deprecated nvte_get_fused_attn_backend wrapper (lines 347-378) needs cfg.o_format and cfg.do_format populated before delegating to nvte_get_fused_attn_backend_v2.

Important Files Changed

Filename | Overview
transformer_engine/common/fused_attn/fused_attn.cpp | Core backend selection rewritten to delegate to cuDNN probe calls; deprecated wrapper missing o_format/do_format initialization is a regression for external C API callers; fwd path hardcodes deterministic=false when building the backward probe graph
transformer_engine/common/include/transformer_engine/fused_attn.h | Introduces NVTEFusedAttnConfig versioned struct and nvte_get_fused_attn_backend_v2; adds cudnnHandle_t dependency to the public C header; NVTE_FUSED_ATTN_CONFIG_INIT macro zero-initializes optional format fields correctly
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu | Adds is_supported_f16_fwd/bwd probe functions that correctly read o_format, do_format, dqkv_layout, and bottom_right_diagonal from the config struct; probes build and cache cuDNN graphs with null device pointers
transformer_engine/common/fused_attn/fused_attn_fp8.cu | Adds is_supported_fp8_fwd/bwd probe functions; both correctly propagate do_format, dqkv_layout, and scaling_mode from the config struct; do_dtype in bwd is derived from o_dtype, which is consistent with existing assumptions
transformer_engine/pytorch/attention/dot_product_attention/utils.py | FP8 scaling_mode, o_type, and qkv_scale_inv_format now correctly propagated to the backend probe; QKVFormat[None] safely resolves to NVTE_QKV_Format_NOT_SET for non-MXFP8 paths
transformer_engine/jax/cpp_extensions/attention.py | FusedAttnHelper updated with batch_size, bottom_right_diagonal, and attn_scale; get_fused_attn_backend now returns (backend, message) tuple; o_type fixed to q_type and scaling_mode to NVTE_INVALID_SCALING for non-FP8 JAX paths, which is correct
transformer_engine/jax/csrc/extensions/attention.cpp | GetFusedAttnBackend correctly normalizes NVTE_QKV_Format_NOT_SET to q_format for o_format and do_format; pybind bindings updated to expose new NVTE_QKV_Layout/Format NOT_SET values and NVTEScalingMode enum

Sequence Diagram

sequenceDiagram
    participant Caller as Caller (Python/C)
    participant v2 as nvte_get_fused_attn_backend_v2
    participant probe_f16 as is_supported_f16_fwd/bwd
    participant probe_fp8 as is_supported_fp8_fwd/bwd
    participant impl as fused_attn_*_impl (null ptrs)
    participant cache as cuDNN Graph Cache

    Caller->>v2: "NVTEFusedAttnConfig + &message"
    v2->>v2: early checks (64-bit ragged offset, THD mask, dtype)
    alt FP16/BF16
        v2->>probe_f16: is_supported_f16_fwd(cfg, handle)
        probe_f16->>impl: "call with devPtr=nullptr"
        impl->>cache: build + cache cuDNN graph
        cache-->>impl: ok / exception
        impl-->>probe_f16: workspace_size or throw
        probe_f16-->>v2: empty string (success) or error string
        opt is_training
            v2->>probe_f16: is_supported_f16_bwd(cfg, handle)
            probe_f16->>impl: "call with devPtr=nullptr"
            impl->>cache: build + cache cuDNN bwd graph
            probe_f16-->>v2: empty string or error string
        end
        v2-->>Caller: NVTE_F16_arbitrary_seqlen or NVTE_No_Backend
    else FP8
        v2->>probe_fp8: is_supported_fp8_fwd(cfg, handle)
        probe_fp8->>impl: "call with devPtr=nullptr"
        impl->>cache: build + cache FP8 fwd graph
        probe_fp8-->>v2: empty string or error string
        opt is_training
            v2->>probe_fp8: is_supported_fp8_bwd(cfg, handle)
            probe_fp8-->>v2: empty string or error string
        end
        v2-->>Caller: NVTE_FP8 or NVTE_No_Backend
    end


const NVTE_QKV_Format qkv_format = nvte_get_qkv_format(qkv_layout);
if (qkv_format != NVTE_QKV_Format::NVTE_BSHD && qkv_format != NVTE_QKV_Format::NVTE_SBHD &&
qkv_format != NVTE_QKV_Format::NVTE_BHSD) {
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
Contributor


P0 Compilation error: invalid pointer arithmetic in string construction

NVTE_QKV_Format is an unscoped C enum, so "..." + qkv_format performs pointer arithmetic (advancing the const char* literal by the integer value of the enum). The result is a const char*, and then const char* + "." tries to add two pointers, which is ill-formed in C++. This expression will not compile. The same bug exists on line 1380 (is_supported_fp8_bwd). The fix is to use std::to_string(static_cast<int>(qkv_format)) or construct a std::string explicitly.

Suggested change
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " +
std::to_string(static_cast<int>(qkv_format)) + ".";
}
size_t workspace_size = 0;
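
The pointer-arithmetic pitfall flagged above can be reproduced in isolation. Format here is a stand-in for the real NVTE_QKV_Format enum, not TE's actual definition:

```cpp
#include <string>

// Stand-in for NVTE_QKV_Format: an unscoped enum that freely converts to int.
enum Format { FORMAT_BSHD = 0, FORMAT_SBHD = 1, FORMAT_THD = 2 };

// Broken (does not compile): "literal" + fmt advances the const char* by the
// enum's integer value, and appending "." is then pointer + pointer:
//   return "found " + fmt + ".";
//
// Fixed: convert the enum to a std::string first, so operator+ resolves to
// the std::string overloads throughout.
std::string unsupported_format_message(Format fmt) {
  return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " +
         std::to_string(static_cast<int>(fmt)) + ".";
}
```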

Comment on lines +1380 to +1382
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
}
const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);
Contributor


P0 Same pointer-arithmetic compilation error as in is_supported_fp8_fwd — adding qkv_format (an integer-valued enum) to a string literal produces a const char*, and then adding "." yields an ill-formed pointer+pointer expression. Use std::to_string to produce a proper string.

Suggested change
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " + qkv_format + ".";
}
const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);
return "FP8 fused attention supports BSHD/SBHD/BHSD formats, found " +
std::to_string(static_cast<int>(qkv_format)) + ".";
}
const cudnn_frontend::DataType_t qkv_t = get_cudnn_fe_dtype(qkv_dtype);

Comment on lines +292 to +293
constexpr size_t probe_batch = 1;
constexpr bool probe_bottom_right_diagonal = false;
Contributor


P1 Hardcoded probe parameters may silently misclassify configurations

probe_batch = 1 and probe_bottom_right_diagonal = false are used for all cudnn-frontend probe calls regardless of the actual values passed by the caller. When attn_mask_type is NVTE_CAUSAL_BOTTOM_RIGHT_MASK or NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK, the probe uses false for bottom_right_diagonal even though the real call will use true, which can cause the probe to greenlight configurations that fail at actual execution time.

Comment on lines +232 to +238
thread_local std::string fused_attn_backend_message_buffer;

void set_message(const char **message, const std::string &reason) {
if (message == nullptr) return;
fused_attn_backend_message_buffer = reason;
*message = fused_attn_backend_message_buffer.c_str();
}
Contributor


P2 Thread-local buffer invalidated on next call on the same thread

*message is set to the .c_str() of the thread-local fused_attn_backend_message_buffer. Any subsequent call to nvte_get_fused_attn_backend on the same thread will clear and overwrite this buffer, invalidating the raw pointer before the caller has a chance to copy it. The internal probe calls currently pass nullptr for message, so the buffer won't be clobbered by them, but this invariant is fragile and easy to break in a future change.
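
The hazard can be modeled in a few lines; set_message below is an illustrative reduction of the pattern, and the defensive caller copies the string before any further call on the same thread.

```cpp
#include <string>

// Minimal model of the pattern: the returned pointer aliases a thread-local
// buffer that the next call on this same thread will overwrite.
thread_local std::string message_buffer;

const char *set_message(const std::string &reason) {
  message_buffer = reason;
  return message_buffer.c_str();
}
```

A caller that holds the raw pointer across a second call observes the new contents (or a dangling pointer if the string reallocates); copying into a std::string immediately is the defensive pattern.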

Comment on lines +232 to +233
bool return_max_logit, bool cuda_graph, bool deterministic, cudnnHandle_t handle,
const char **message);
Contributor


P1 Breaking public API change: cudnnHandle_t added as a required parameter

nvte_get_fused_attn_backend is a public C API (no namespace, exported symbol). Adding cudnnHandle_t handle as a required parameter, and including <cudnn.h> in the public header, is a breaking change that forces every downstream consumer to hold and pass a cuDNN handle for what was previously a pure-metadata query.

cyanguwa and others added 2 commits May 7, 2026 17:22
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Comment on lines +281 to +290
// avoid CUDA graph issue with cuDNN <= 9.15
if (cudnn_runtime_version <= 91500 && is_training &&
(qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) &&
(max_seqlen_kv % 128 != 0) && cuda_graph &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK &&
attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK) {
set_message(message, "Known cuDNN <= 9.15 issue with CUDA graph. Please upgrade cuDNN.");
return NVTE_Fused_Attn_Backend::NVTE_No_Backend;
}
Contributor


P1 CUDA graph guard now incorrectly rejects FP8 configurations

The CUDA graph check at lines 281–290 was previously scoped exclusively to the F16/BF16 branch; moving it to the top-level pre-filter means it now also rejects FP8 + CUDA graph configurations on cuDNN ≤ 9.15.0 that were previously accepted. Users running FP8 training with CUDA graph capture, BSHD/SBHD layout, non-padding masks, and max_seqlen_kv % 128 != 0 on cuDNN ≤ 9.15.0 will see the backend silently downgrade to NVTE_No_Backend where it used to return NVTE_FP8. If the cuDNN bug does not affect FP8, this is a regression; if it does, the guard should be narrowed with a comment explaining why it applies to FP8.

cyanguwa and others added 3 commits May 7, 2026 18:30
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
@cyanguwa
Collaborator Author

cyanguwa commented May 8, 2026

/te-ci L1

Comment on lines +1427 to +1428
const NVTE_QKV_Format do_format = o_format;
const NVTE_QKV_Layout dqkv_layout = qkv_layout;
Contributor


P1 Hardcoded dqkv_layout and do_format in backward probe may build incorrect cached graph

is_supported_f16_bwd hardcodes dqkv_layout = qkv_layout and do_format = q_format. nvte_fused_attn_bwd accepts independent dqkv_layout and do_format parameters and passes them through to the actual backward kernel — but these are never forwarded into nvte_get_fused_attn_backend. When the activation and gradient layouts differ, the probe builds and caches a cuDNN graph for a configuration that won't be used. More critically, if the config with qkv_layout is unsupported but the config with the true dqkv_layout would be supported, backend selection will falsely return NVTE_No_Backend. The same assumption exists in is_supported_fp8_bwd at line 1385 of fused_attn_fp8.cu.

const int64_t bias_sq = has_bias ? sq : 0;
const int64_t bias_skv = has_bias ? skv : 0;

const NVTE_QKV_Format o_format = q_format;
Contributor


P1 o_format = q_format assumption can mismatch the actual forward call

is_supported_f16_fwd derives o_format from q_format, but nvte_fused_attn_fwd accepts a separate o_format parameter that is not forwarded into nvte_get_fused_attn_backend. When the caller uses a different output format from the query format — such as returning BSHD output from an SBHD_BSHD_BSHD layout — the probe builds a cuDNN graph for the wrong o_format. If the graph with o_format=q_format is accepted but the config with the actual o_format is not (or vice versa), nvte_get_fused_attn_backend produces an incorrect backend decision, causing an error when the actual kernel is invoked.


// For ragged offsets we only support 32-bit prior to cuDNN 9.5
// Only used when THD format is requested.
cudnnHandle_t handle = cudnnExecutionPlanManager::Instance().GetHandle();
Contributor


P1 Backend query now builds cuDNN execution graphs as a side effect, altering call semantics

cudnnExecutionPlanManager::Instance().GetHandle() is obtained here, and the subsequent probe calls invoke fused_attn_*_impl with null device pointers, building and caching cuDNN execution plans. This turns a previously cheap metadata query into a graph-build operation that can take hundreds of milliseconds on first call. Code paths that call nvte_get_fused_attn_backend as an availability check — notably FusedAttnHelper.is_fused_attn_kernel_available() from Python/JAX — will pay this cost unexpectedly during model initialisation. The new semantic should be documented in the public API header.

Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
const int64_t bias_sq = has_bias ? sq : 0;
const int64_t bias_skv = has_bias ? skv : 0;

const NVTE_QKV_Format o_format = q_format;
Contributor


P1 o_format probe hardcoded as q_format, mismatching callers that use a different output format

is_supported_f16_fwd derives o_format = q_format (line 1372) and uses it to build and cache the cuDNN graph. However, nvte_fused_attn_fwd accepts an independent o_format parameter that is never forwarded into nvte_get_fused_attn_backend. When a caller uses an output format different from the query format (e.g. BSHD output from an SBHD_BSHD_BSHD layout), the cached graph was built for q_format, not the actual o_format. If cuDNN accepts the wrong graph but rejects the real one — or vice versa — the backend check produces an incorrect decision, causing an unexpected error at actual kernel invocation.

Comment on lines +1426 to +1428
const NVTE_QKV_Format o_format = q_format;
const NVTE_QKV_Format do_format = o_format;
const NVTE_QKV_Layout dqkv_layout = qkv_layout;
Contributor


P1 Backward probe hardcodes dqkv_layout = qkv_layout and do_format = o_format, diverging from the actual backward call

is_supported_f16_bwd sets dqkv_layout = qkv_layout and do_format = o_format (where o_format is already fixed to q_format). The actual backward call nvte_fused_attn_bwd accepts independent dqkv_layout and do_format parameters that are never threaded through nvte_get_fused_attn_backend. When activation and gradient layouts differ, the probe builds a cuDNN graph for a configuration that is never used at runtime. More critically, if the real dqkv_layout is unsupported but the probe's assumed qkv_layout is accepted, the backend check returns NVTE_F16_arbitrary_seqlen and the actual backward pass silently fails. The same issue exists in is_supported_fp8_bwd at line 1385 of fused_attn_fp8.cu.

cyanguwa and others added 3 commits May 7, 2026 22:28
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Comment on lines +444 to +448
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
return_max_logit, cuda_graph, false);
is_training, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, bias_type, attn_mask_type,
softmax_type, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left,
window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph,
/*deterministic=*/false, /*message=*/nullptr);
Contributor


P1 Probe graph builds triggered on every forward and backward execution step

nvte_fused_attn_fwd and nvte_fused_attn_bwd both now call nvte_get_fused_attn_backend internally, which in turn calls the probe functions (is_supported_f16_fwd, is_supported_f16_bwd, etc.) with null device pointers. The cudnn-frontend cache is keyed on the full graph specification including batch_size. When the runtime batch size b (extracted from the actual tensor shapes) differs from the value used in the user-facing availability check — for example the last mini-batch in an epoch or an inference request with a different batch size — none of the graphs built during the pre-check will match, and a full graph build is triggered mid-execution. These builds can take hundreds of milliseconds and will stall the training step unexpectedly. Designs to mitigate this include: caching a "any batch size" graph variant, or making nvte_fused_attn_fwd/nvte_fused_attn_bwd route to the backend without re-invoking the probes.
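
A toy cache keyed on the full configuration shows why a probe at one batch size does not warm the cache for another. This is purely illustrative and not TE's or cudnn-frontend's actual cache implementation:

```cpp
#include <map>
#include <string>
#include <tuple>

// Cache key includes batch size, mirroring the full-graph-spec keying
// described above.
using Key = std::tuple<int /*batch*/, int /*max_seqlen*/>;
std::map<Key, std::string> graph_cache;
int build_count = 0;  // a real graph build can take hundreds of milliseconds

const std::string &get_graph(int batch, int max_seqlen) {
  Key key{batch, max_seqlen};
  auto it = graph_cache.find(key);
  if (it == graph_cache.end()) {
    ++build_count;  // cache miss: pay the full build cost mid-execution
    it = graph_cache.emplace(key, "built graph").first;
  }
  return it->second;
}
```

Probing at batch=32 warms nothing for the last, smaller mini-batch of an epoch, so that step pays a second build.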

Comment on lines 140 to 162
@@ -146,6 +157,7 @@ def get_fused_attn_backend(self):
self.head_dim_v,
self.window_size[0],
self.window_size[1],
self.bottom_right_diagonal,
not self.is_non_deterministic_allowed(),
)
Contributor


P1 o_type and scaling_mode are always fixed regardless of JAX FP8 DPA config

FusedAttnHelper.get_fused_attn_backend() always passes o_type = q_type and scaling_mode = NVTEScalingMode.NVTE_INVALID_SCALING. For non-FP8 JAX attention this is correct, but the same helper is the one called from FusedAttnFwdPrimitive.abstract to determine which backend to activate. If a JAX caller configures FP8 DPA (e.g. FP8 q_dtype with BF16 output), the probe will enter the FP8 branch of nvte_get_fused_attn_backend with o_dtype=q_dtype (wrong) and scaling_mode=NVTE_INVALID_SCALING (likely unsupported by the FP8 cuDNN graphs), causing the probe to return NVTE_No_Backend and silently falling back to unfused attention even though the configuration would be supported with the correct o_type and scaling_mode. The FusedAttnHelper dataclass should expose o_dtype and scaling_mode fields analogous to what utils.py computes for PyTorch.

cyanguwa and others added 2 commits May 8, 2026 12:19
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Comment on lines +454 to +460
NVTE_Fused_Attn_Backend fused_attention_backend = nvte_get_fused_attn_backend(
is_training, Q_type, KV_type, qkv_layout, bias_type, attn_mask_type, softmax_type, dropout,
h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left, window_size_right,
return_max_logit, cuda_graph, false);
is_training, b, Q_type, KV_type, O_type, scaling_mode, qkv_layout, o_format,
/*do_format=*/o_format, /*dqkv_layout=*/qkv_layout, qkv_scale_inv_format,
/*do_scale_inv_format=*/qkv_scale_inv_format, bias_type, attn_mask_type, softmax_type,
attn_scale, dropout, h_q, h_kv, max_seqlen_q, max_seqlen_kv, d_qk, d_v, window_size_left,
window_size_right, bottom_right_diagonal, return_max_logit, cuda_graph,
/*deterministic=*/false, /*message=*/nullptr);
Contributor


P1 Forward path re-probes backward with deterministic=false, wasting graph cache entries for deterministic training

nvte_fused_attn_fwd calls nvte_get_fused_attn_backend with is_training=is_training and hardcoded deterministic=false. When a user trains with deterministic=true, the forward invocation builds and caches a cuDNN graph for the non-deterministic backward. A few lines later, nvte_fused_attn_bwd calls nvte_get_fused_attn_backend again with deterministic=true, causing the backward graph to be rebuilt from scratch on every step. The deterministic flag should be threaded through to this site (as it already is in the nvte_fused_attn_bwd call), or the forward pass should skip the backward probe entirely and let nvte_fused_attn_bwd drive its own backend check.

cyanguwa and others added 4 commits May 8, 2026 12:35
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
Comment on lines +353 to 378
NVTEFusedAttnConfig cfg = NVTE_FUSED_ATTN_CONFIG_INIT;
cfg.qkv_layout = qkv_layout;
cfg.dqkv_layout = qkv_layout; // legacy: gradient layout matches input layout
cfg.bias_type = bias_type;
cfg.attn_mask_type = attn_mask_type;
cfg.softmax_type = softmax_type;
cfg.attn_scale = 1.0f; // legacy default; matches the value pre-PR probes hardcoded
cfg.dropout = dropout;
cfg.max_seqlen_q = max_seqlen_q;
cfg.max_seqlen_kv = max_seqlen_kv;
cfg.window_size_left = window_size_left;
cfg.window_size_right = window_size_right;
cfg.cuda_graph = cuda_graph;
NVTE_CHECK(q_dtype == kv_dtype, "Q and KV must have the same data type.");
cfg.qkv_dtype = q_dtype;
cfg.o_dtype = q_dtype; // legacy: O dtype matches Q dtype
cfg.batch_size = 1; // legacy: pre-PR probes assumed batch=1
cfg.num_attn_heads = num_attn_heads;
cfg.num_gqa_groups = num_gqa_groups;
cfg.head_dim_qk = head_dim_qk;
cfg.head_dim_v = head_dim_v;
cfg.is_training = is_training;
cfg.return_max_logit = return_max_logit;
cfg.deterministic = deterministic;
return nvte_get_fused_attn_backend_v2(&cfg, /*message=*/nullptr);
}
Contributor


P1 Deprecated wrapper omits cfg.o_format and cfg.do_format

nvte_get_fused_attn_backend builds an NVTEFusedAttnConfig via NVTE_FUSED_ATTN_CONFIG_INIT but never sets cfg.o_format or cfg.do_format, leaving both as NVTE_QKV_Format_NOT_SET. The probes (is_supported_f16_fwd, is_supported_f16_bwd, is_supported_fp8_fwd) then pass this sentinel directly to fused_attn_*_impl. If cuDNN rejects NVTE_QKV_Format_NOT_SET as an output format — which is the expected behaviour for an unrecognised enum value — every call through the deprecated API returns NVTE_No_Backend for configurations that were previously accepted. A minimal fix is to derive the missing fields from qkv_layout the same way GetFusedAttnBackend (JAX) does: cfg.o_format = nvte_get_q_format(qkv_layout); cfg.do_format = cfg.o_format;.
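
The minimal fix suggested above amounts to deriving both output formats from the layout before delegating. The enums and get_q_format_stub below are illustrative stand-ins, not TE's real NVTE enums or the real nvte_get_q_format:

```cpp
// Illustrative stand-ins for NVTE_QKV_Format / NVTE_QKV_Layout.
enum Format { FORMAT_NOT_SET = -1, FORMAT_BSHD = 0, FORMAT_SBHD = 1 };
enum Layout { LAYOUT_BS3HD = 0, LAYOUT_SB3HD = 1 };

struct Config {
  Format o_format = FORMAT_NOT_SET;
  Format do_format = FORMAT_NOT_SET;
};

// Stand-in for nvte_get_q_format: extract the Q format from a packed layout.
Format get_q_format_stub(Layout layout) {
  return layout == LAYOUT_BS3HD ? FORMAT_BSHD : FORMAT_SBHD;
}

// The deprecated wrapper must populate both fields instead of leaving the
// NOT_SET sentinel for the probes to reject.
void fill_legacy_formats(Config &cfg, Layout layout) {
  cfg.o_format = get_q_format_stub(layout);
  cfg.do_format = cfg.o_format;
}
```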

@cyanguwa
Collaborator Author

cyanguwa commented May 8, 2026

/te-ci L1

