
[PyTorch] Batch CP attention tests in single torchrun to amortize NCC…#2965

Open
sudhakarsingh27 wants to merge 8 commits into NVIDIA:main from sudhakarsingh27:sudhakars/cp_test_batching_pr

Conversation


@sudhakarsingh27 sudhakarsingh27 commented May 6, 2026

Design

Problem

8×H100, test_essential=True, 38 runnable CP attention configs:

  • ~554 s wall when each config runs in its own torchrun.
  • Of that, only ~216 s is the test work itself (CP fwd/bwd + assertions, ~5.7 s/config).
  • The remaining ~338 s is per-spawn overhead — Python imports + NCCL global init/teardown — paid 38 times at ~8.9 s each.

We need that overhead amortized without changing how tests are written or how skips are reported.

Approach

A session-scoped fixture (_cp_batch_results) does two passes:

  1. Collect (dry-run, in-process). Walk pytest's collected items. For each item that requests _cp_batch_results, call its test function directly with a stubbed request. The body executes its inline pytest.skip(...) checks normally; if any fires, the item is dropped from the batch. Otherwise the body's final call to _run_or_fetch(...) records its kwargs in a module-level dict instead of launching a subprocess.
  2. Batch + execute. Group recorded kwargs by num_gpus_per_node, chunk into batches of CP_TEST_BATCH_SIZE (default 16), launch one torchrun per chunk. Worker (run_attention_with_cp.py) inits NCCL once, loops over configs, atomically flushes per-config results to <batch>.results.json. When pytest later runs each test for real, the body re-evaluates skips and _run_or_fetch looks up the recorded result.

How dry-run works

@pytest.fixture(scope="session")
def _cp_batch_results(request):
    global _COLLECT_MODE  # module-level flag read by _run_or_fetch
    items = [it for it in request.session.items
             if "_cp_batch_results" in getattr(it, "fixturenames", ())]
    _COLLECT_MODE = True
    for item in items:
        if _item_static_skip(item):
            continue
        try:
            _dry_run_item(item)
        except pytest.skip.Exception:
            pass
        except BaseException:
            pass  # surfaces in execute mode as a normal pytest error
    _COLLECT_MODE = False
    # group _COLLECTED_KWARGS by num_gpus, chunk, run torchrun batches
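The grouping step behind that final comment is simple bookkeeping. A minimal sketch, assuming _COLLECTED_KWARGS maps nodeid → recorded kwargs and reusing the _run_one_batch helper name from the PR (its exact signature is an assumption here):

import os

CP_TEST_BATCH_SIZE = int(os.getenv("CP_TEST_BATCH_SIZE", "16"))

def _dispatch_batches(collected_kwargs):
    # Group recorded configs by the GPU count they need.
    by_num_gpus = {}
    for nodeid, kwargs in collected_kwargs.items():
        by_num_gpus.setdefault(kwargs["num_gpus"], []).append((nodeid, kwargs))
    results = {}
    for num_gpus, entries in by_num_gpus.items():
        # Chunk each group into batches of at most CP_TEST_BATCH_SIZE configs.
        for start in range(0, len(entries), CP_TEST_BATCH_SIZE):
            chunk = entries[start:start + CP_TEST_BATCH_SIZE]
            # One torchrun per chunk; the worker flushes <batch>.results.json.
            results.update(_run_one_batch(num_gpus, chunk))
    return results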

_dry_run_item calls the underlying function with the same parametrize values pytest would have passed:

def _dry_run_item(item):
    func = item.function
    params = dict(item.callspec.params)
    func(_DummyRequest(item.nodeid), {}, **params)

This bypasses pytest's runner entirely — no fixture setup hooks, no plugin reporters, no captured-stdout machinery.

_run_or_fetch checks a module-level _COLLECT_MODE flag:

def _run_or_fetch(request, batch_results, *, num_gpus_per_node, **worker_kwargs):
    if _COLLECT_MODE:
        _COLLECTED_KWARGS[request.node.nodeid] = dict(num_gpus=num_gpus_per_node, **worker_kwargs)
        return  # never reaches the lookup; never asserts
    entry = batch_results.get(request.node.nodeid)
    ...

In collect mode it's a recorder, in execute mode it's a result-fetcher. The test body doesn't know which mode it's in — it just calls one helper at the end.
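For completeness, a hedged sketch of what the elided execute-mode branch plausibly does — not the PR's exact code; the ok/err field names are inferred from the worker snippets quoted later in review:

import pytest

def _fetch_result(request, batch_results):
    # Hypothetical helper mirroring the execute-mode tail of _run_or_fetch.
    entry = batch_results.get(request.node.nodeid)
    if entry is None:
        # No recorded result (e.g. the batch crashed before flushing this config
        # and the singleton retry also failed); report it as a test failure.
        pytest.fail(f"no batched result recorded for {request.node.nodeid}")
    if not entry["ok"]:
        pytest.fail(entry["err"] or "CP config failed in the batch worker")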

Stubs and skip handling

| Param | Stub | Why |
| --- | --- | --- |
| request | _DummyRequest(nodeid) — only request.node.nodeid | _run_or_fetch only reads nodeid; the test body never touches request. |
| _cp_batch_results | {} (empty dict) | _run_or_fetch returns early in collect mode and never inspects batch_results. |
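The request stub can be as small as the table implies; a sketch (the class in the PR may carry more, but only .node.nodeid is required):

class _DummyRequest:
    """Minimal stand-in for pytest's request during dry-run; only .node.nodeid is read."""

    class _Node:
        def __init__(self, nodeid):
            self.nodeid = nodeid

    def __init__(self, nodeid):
        self.node = self._Node(nodeid)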

Inline pytest.skip("reason") raises pytest.skip.Exception. The dry-run loop catches per-item, drops the item from the batch, and moves on. In execute mode the same line raises again; pytest reports SKIPPED with the same reason.

@pytest.mark.skip and @pytest.mark.skipif(<bool_condition>) markers don't fire when calling item.function(...) directly. _item_static_skip(item) walks item.iter_markers("skip"|"skipif") and reads marker.args[0] (the condition) before the dry-run, dropping items the markers would otherwise skip.
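A sketch of that marker walk, assuming boolean skipif conditions (string conditions, which pytest also accepts, would need evaluation and are ignored here):

def _item_static_skip(item):
    # True if a skip/skipif marker would have skipped this item before its body ran.
    for _ in item.iter_markers("skip"):
        return True
    for marker in item.iter_markers("skipif"):
        if marker.args and bool(marker.args[0]):
            return True
    return False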

Cost of running the body twice

For each item the body runs once during dry-run and once during execute. Skip checks are pure Python; the only non-trivial work is get_available_attention_backends, cached per nodeid via _BACKEND_CACHE so the second call is a dict hit. Measured on full test_essential=True (10272 collected items, 38 runnable): 530 cache lookups, 0.03 s total.
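The cache itself is just a module-level dict keyed by nodeid. The helper name _cached_backend_check appears in a later commit message, but this body and the pass-through signature are assumptions:

_BACKEND_CACHE = {}

def _cached_backend_check(nodeid, *args, **kwargs):
    # Memoize the expensive backend probe per test nodeid so the dry-run and
    # execute passes see the same answer and pay the probe cost only once.
    if nodeid not in _BACKEND_CACHE:
        _BACKEND_CACHE[nodeid] = get_available_attention_backends(*args, **kwargs)
    return _BACKEND_CACHE[nodeid]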

End-to-end pytest overhead (dry-run + collection, with torchrun stubbed): ~14 s wall, of which ~6.6 s is module-import startup, ~3.2 s is pytest per-item setup, 0.2 s is the batching infra itself. Negligible vs the GPU work it dispatches.

Performance

8×H100, test_essential=True (38 runnable configs: 34 × 2-GPU + 4 × 4-GPU). In unbatched mode each config is its own torchrun. In batched mode, configs sharing the same num_gpus_per_node are grouped into one torchrun of up to CP_TEST_BATCH_SIZE configs.

| Run | Torchrun spawns | Wall | Speedup |
| --- | --- | --- | --- |
| Unbatched | 38 (one per config) | ~554 s | 1.0× |
| B=16 | 4 (16+16+2 @ 2 GPU; 4 @ 4 GPU) | 274 s | 2.0× |
| B=32 | 3 (32+2 @ 2 GPU; 4 @ 4 GPU) | 248 s | 2.2× |
| B=50 | 2 (34 @ 2 GPU; 4 @ 4 GPU) | 237 s | 2.3× |

Where the 2× comes from

  1. 34 fewer torchrun spawns. Each saved spawn cuts ~12 s of startup (Python imports + NCCL global init/teardown) — measured directly from the wall-time delta B=50 → B=16 (37 s saved across 2 fewer spawns).
  2. ~1.2 s lower per-config work (~46 s total). Sharing NCCL global state across configs in a batch drops per-config wall from 5.7 s to 4.5 s; only the per-config CP comm-group create/destroy remains.

The spawn savings dominate.

Picking CP_TEST_BATCH_SIZE

For these 38 configs (34 × 2-GPU + 4 × 4-GPU):

  • B=16 → B=32: −27 s (one fewer 2-GPU spawn).
  • B=32 → B=50: −11 s (one more spawn removed, but the merged batch is now large enough that bookkeeping eats into the savings).

16 and 32 are ballparks for this matrix — once B exceeds the largest GPU-group size (here, 34), all configs in that group already share a torchrun and further increases do nothing. With a larger config matrix (e.g. full test_essential=False ≈ 348 runnable), the same logic implies B should scale up too: pick it so the largest GPU-group has only a small number of torchrun spawns, but not so large that a single batch becomes long enough that a worker crash loses too much progress.
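The spawn count for a given B falls out of a ceiling division per GPU-count group; for the matrix above:

from math import ceil

def torchrun_spawns(group_sizes, batch_size):
    # One torchrun per chunk of at most batch_size configs, per GPU-count group.
    return sum(ceil(n / batch_size) for n in group_sizes)

# The 38-config essential matrix: 34 two-GPU configs + 4 four-GPU configs.
assert torchrun_spawns([34, 4], 16) == 4
assert torchrun_spawns([34, 4], 32) == 3
assert torchrun_spawns([34, 4], 50) == 2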

Knobs

| Env var | Effect |
| --- | --- |
| CP_TEST_BATCH_SIZE=N | Configs per torchrun. Default 16. Set to 1 to bisect. |
| CP_TEST_BATCH_RETRY=0 | Disable the singleton retry for unattributed crashes. |

Adding a batched test

  1. Write the test the way you would any CP test: @pytest.mark.parametrize stack + inline pytest.skip(...) checks.
  2. Add request, _cp_batch_results to the function signature.
  3. Replace the trailing run_distributed(get_bash_arguments(...)) with _run_or_fetch(request, _cp_batch_results, num_gpus_per_node=N, ...) (kwargs become the worker's run_dpa_with_cp(**kwargs) arguments).

That's the entire wiring.
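A sketch of the resulting test shape — the test name, parametrize values, skip condition, and worker kwargs below are illustrative, not the PR's actual tests, and the snippet assumes the module's existing imports and helpers (_run_or_fetch, the session fixture, TE's get_device_compute_capability):

@pytest.mark.parametrize("dtype", ["bf16", "fp16"])
@pytest.mark.parametrize("model", ["cp_1_0"])
def test_cp_with_new_feature(dtype, model, request, _cp_batch_results):
    # Inline skips behave exactly as before, in both dry-run and execute mode.
    if get_device_compute_capability() < (9, 0):
        pytest.skip("CP tests require sm90 or newer")
    # Previously: run_distributed(get_bash_arguments(...)).
    _run_or_fetch(
        request,
        _cp_batch_results,
        num_gpus_per_node=2,
        dtype=dtype,
        model=model,  # kwargs are forwarded to run_dpa_with_cp(**kwargs) in the worker
    )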

Failure semantics

| Outcome | What pytest sees |
| --- | --- |
| Inline pytest.skip(...) fires | Standard SKIP (re-evaluated in execute mode; short-circuits before _run_or_fetch). |
| @pytest.mark.skip / @pytest.mark.skipif marker fires | Standard SKIP via pytest's normal path (not queued for torchrun). |
| Config ran, assertion failed | FAIL with the worker's traceback. |
| Assertion fired on rank > 0 only | FAIL via cross-rank dist.all_reduce(ok, op=MIN). |
| Worker subprocess crashed before flush | Each affected config is retried as a singleton; a real result wins, residual crashes surface as FAIL with attribution. |
| Dry-run itself raised | Caught and ignored in the fixture; the same exception fires in execute mode and pytest reports a normal test ERROR. |

Mitigations for shared-process state

Configs in a batch share one Python process and one NCCL world, so anything that needs a clean per-test starting point is reset explicitly (a condensed sketch of the per-config loop follows this list):

  • Per-config NCCL sub-group destruction (cp_comm_group, a2a+p2p sub-groups).
  • Reset _TRANSIENT_ENV_KEYS between configs (NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FP8_DPA_BWD, NVTE_DPA_FP8CS_O_in_F16, NVTE_ALLOW_NONDETERMINISTIC_ALGO).
  • torch.cuda.empty_cache(), dist.barrier() between configs.
  • RNG re-seed (1234) at start of each config.
  • copy.deepcopy(model_configs_*[model]) in the worker (THD path mutates attn_mask_type).
  • Atomic per-config flush (tmp + os.replace): a partial JSON is never visible to the reader.
  • Cross-rank dist.all_reduce(ok, op=MIN) after each config so any rank's failure flips ok to False.
  • Auto-retry crashed batch entries as singletons; disable via CP_TEST_BATCH_RETRY=0.
  • arg.split("=", 1) so kwarg values containing = (paths) survive.
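A condensed sketch of how those resets compose in the worker's per-config loop. The structure is inferred from this list; the actual run_attention_with_cp.py differs in detail and also resets FP8 global state, transient NVTE_* env vars, and deepcopies model configs, which are omitted here:

import json, os, traceback

import torch
import torch.distributed as dist

def _run_configs(configs, results_path):
    results = {}
    for nodeid, kwargs in configs:
        torch.manual_seed(1234)  # fresh RNG per config
        ok, err = True, None
        try:
            run_dpa_with_cp(**kwargs)
        except Exception:
            ok, err = False, traceback.format_exc()
        # Any rank's failure flips ok to False on every rank.
        flag = torch.tensor([int(ok)], device="cuda")
        dist.all_reduce(flag, op=dist.ReduceOp.MIN)
        results[nodeid] = {"ok": bool(flag.item()), "err": err}
        torch.cuda.empty_cache()
        dist.barrier()
        if dist.get_rank() == 0:
            # Atomic flush: write to a temp file, then os.replace, so a reader
            # never sees a partially written JSON even if the next config crashes.
            tmp = results_path + ".tmp"
            with open(tmp, "w") as f:
                json.dump(results, f)
            os.replace(tmp, results_path)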

Edge cases

  1. request API surface during dry-run. Only request.node.nodeid is provided. A future test that uses request.config.getoption(...) or request.getfixturevalue(...) would AttributeError during dry-run. The fixture catches BaseException so the same error fires in execute mode where pytest's real request is available.
  2. @pytest.mark.skipif(condition_evaluated_at_runtime). A skipif whose condition becomes True only at execute time would not be detected by _item_static_skip. The condition still fires correctly in execute mode; we'd just have wasted one batch slot for it.
  3. get_available_attention_backends non-determinism. If this returns different values between dry-run and execute (driver state changes), a config queued by collect could skip in execute. Harmless: _run_or_fetch is never reached, the unused batch result is garbage-collected.
  4. Pytest internals. The dry-run uses item.function, item.callspec.params, and pytest.skip.Exception. Stable in pytest 7+/8+. If they shift, _dry_run_item is a 3-line shim to update.

Validation

8×H100, test_essential=True: 38 passed / 10234 skipped / 0 unrelated failures.

Stress (no regressions): single nodeid, -k <no-match>, --collect-only, small subset, CP_TEST_BATCH_SIZE=1 all behave normally.

Files

  • tests/pytorch/attention/test_attention_with_cp.py — collect/dispatch/fetch infra, dry-run helpers, test bodies updated minimally.
  • tests/pytorch/attention/run_attention_with_cp.py — _init_distributed, main() batch mode, atomic per-config flush, cross-rank aggregation, per-config group teardown, copy.deepcopy of model configs, transient env reset, split("=", 1).

Type of change

  • Code refactoring (test infrastructure; no production-code change)

Checklist

  • Contributing guidelines followed
  • Functionality complete
  • Code commented where non-obvious
  • Documentation (n/a — internal test infra)
  • No new warnings
  • Existing test suite serves as input + validation
  • Existing tests pass locally


greptile-apps Bot commented May 6, 2026

Greptile Summary

This PR replaces per-test torchrun spawns with a session-scoped "batch dispatch" design: a dry-run collects kwargs for all CP attention tests into a module-level dict, groups configs by GPU count, chunks them into batches of up to CP_TEST_BATCH_SIZE, and launches one torchrun per chunk. NCCL is initialized once per batch instead of once per test, cutting wall-time roughly in half.

  • run_attention_with_cp.py gains a main() batch entry point, a shared _init_distributed() helper, per-config FP8GlobalStateManager.reset() + transient-env cleanup, atomic JSON result flushing, and cross-rank dist.all_reduce failure aggregation.
  • test_attention_with_cp.py gains the _cp_batch_results session fixture, the dry-run / collect / dispatch helpers (_run_or_fetch, _run_batch_once, _run_one_batch), a per-nodeid backend cache, and updates both test functions to call _run_or_fetch instead of run_distributed.

Confidence Score: 4/5

Safe to merge for the performance win; several robustness gaps in failure reporting and communicator cleanup were raised in earlier reviews and are not yet addressed.

The batch dispatch logic and result flushing are sound and the 2x speedup is well-validated. The changes that remain open from prior reviews — NCCL communicator leaks on mid-config exceptions, non-rank-0 traceback loss, and the misleading error guidance — are still present in the code and can make diagnosing failures harder in practice.

tests/pytorch/attention/run_attention_with_cp.py — the per-config communicator teardown path and the non-rank-0 error attribution logic warrant a second look before the first production failure surfaces.

Important Files Changed

| Filename | Overview |
| --- | --- |
| tests/pytorch/attention/run_attention_with_cp.py | Added batch entry-point, _init_distributed, _run_single_config, per-config cleanup, cross-rank all_reduce result aggregation, and atomic JSON flushing. Several previously noted robustness gaps remain open (communicator leak on mid-config exception, misleading non-rank-0 error message). |
| tests/pytorch/attention/test_attention_with_cp.py | Session fixture adds dry-run collection, torchrun batch dispatch, per-nodeid backend cache, and singleton retry. The _dry_run_item helper hard-assumes item.callspec exists, which silently misfires for any future non-parametrized test requesting _cp_batch_results. |


@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from fa189b0 to 0e9fc1f Compare May 6, 2026 23:01
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch 4 times, most recently from 7802ec5 to c80df5d Compare May 7, 2026 13:57
Comment on lines +147 to +153
try:
    argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path)
    launch_err = None
    try:
        run_distributed(argv)
    except AssertionError as exc:
        launch_err = str(exc)

P1 Only AssertionError caught from run_distributed

subprocess.run inside run_distributed can raise FileNotFoundError, PermissionError, or OSError for OS-level failures (missing executable, exhausted file descriptors, etc.). These propagate uncaught through _run_batch_once → _run_one_batch → _cp_batch_results. Because the fixture is session-scoped, one such exception causes every test that depends on _cp_batch_results to surface as a fixture ERROR rather than an individual test failure. In the original code, the same OS error would fail only the one test that triggered it. Widening the except clause from AssertionError to Exception before reading the results file would preserve the per-batch isolation benefit.

[PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init

Each parametrized CP test currently spawns its own torchrun process and
pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this
adds up to 1.5-3 hours of pure setup overhead.

This change introduces a session-scoped fixture that:
  1. Calls per-test ``_prepare_*`` helpers to get either a skip reason or
     a kwargs dict for the worker.
  2. Groups runnable configs by ``num_gpus`` and chunks them into batches
     of CP_TEST_BATCH_SIZE (default 16).
  3. Launches one torchrun per chunk; the worker initialises NCCL once
     and runs all configs in the chunk inside the same world.

Per-config results are flushed to JSON after every config so a crash
mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1
to bisect a failing batch.

Also includes a small bugfix in dot_product_attention/utils.py: the
deterministic-FA3 disable condition was firing for any head_dim_qk > 128
(including inference); restrict it to is_training and large head dims.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from 1db76b7 to 6355f62 Compare May 7, 2026 14:14
Comment on lines +762 to +765
dist.destroy_process_group(cp_comm_group)
if cp_comm_type == "a2a+p2p":
    for sg in cp_comm_sub_groups:
        dist.destroy_process_group(sg)

P1 NCCL communicator leak on exception mid-function

cp_comm_group is created at line 238 (and up to 4 sub-groups for a2a+p2p at lines 248-250) but the destroy calls are at the very bottom of the function with no try/finally. Any exception that fires in between — a CUDA OOM, a comparison mismatch, a BaseException from cuDNN — causes _run_single_config to catch it and return (False, traceback), while the communicators are never cleaned up.

In batch mode the problem compounds: with 16 configs per torchrun and any flaky configs, leaked communicators accumulate across the whole batch. NCCL's internal communicator table has a fixed limit (typically 128), so a few hundred batched configs with occasional failures can exhaust it and corrupt subsequent configs with opaque "NCCL error: invalid usage" rather than surfacing the original failure. Wrapping the body after group creation in a try/finally guarantees cleanup.
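A minimal sketch of the suggested shape, with group creation abbreviated (cp_ranks and sub_group_ranks are placeholders, not names from the PR):

import torch.distributed as dist

cp_comm_group = dist.new_group(ranks=cp_ranks, backend="nccl")
cp_comm_sub_groups = []
if cp_comm_type == "a2a+p2p":
    cp_comm_sub_groups = [dist.new_group(ranks=r, backend="nccl") for r in sub_group_ranks]
try:
    ...  # CP forward/backward and comparisons; anything here may raise
finally:
    # Release the NCCL communicators even on CUDA OOM or a comparison mismatch.
    for sg in cp_comm_sub_groups:
        dist.destroy_process_group(sg)
    dist.destroy_process_group(cp_comm_group)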

sudhakarsingh27 (Collaborator, Author) commented:

Tested this PR's batched CP changes on B200 (sm_103) and H100 (sm_90). The H100 run passed because most CP variants gate on sm_103 and skip on sm_90 — only 41 tests actually executed. The B200 run surfaced 143 failures that all share a single root cause — they're not independent bugs.

Failure pattern (every one of the 143 failures has this traceback referencing the same ProcessGroup instance):

torch.distributed.DistBackendError: NCCL communicator was aborted on rank 0.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "tests/pytorch/attention/run_attention_with_cp.py", line 800, in _run_single_config
    run_dpa_with_cp(**kwargs)
  File "tests/pytorch/attention/run_attention_with_cp.py", line 366, in run_dpa_with_cp
    with fp8_context:
  File "/usr/lib/python3.12/contextlib.py", line 144, in __exit__
    next(self.gen)
  File "transformer_engine/pytorch/quantization.py", line 905, in autocast
    FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)
  File "transformer_engine/pytorch/quantization.py", line 649, in autocast_exit
    cls.reduce_and_update_fp8_tensors(forward=True)
  File "transformer_engine/pytorch/quantization.py", line 554, in reduce_and_update_fp8_tensors
    cls.reduce_tensor_across_group_op_max(contiguous_amax, group)
  File "transformer_engine/pytorch/quantization.py", line 518, in reduce_tensor_across_group_op_max
    torch.distributed.all_reduce(...)

ValueError: Process group ... is not initialized in the world group map.
    Please initialize the group first.

Cascade mechanism:

  1. The _cp_batch_results fixture pre-runs CP tests in batches over NCCL.
  2. One batch's NCCL communicator aborts on rank 0 (the original error isn't surfaced as the first failure).
  3. The world process group enters "not initialized" state and doesn't recover.
  4. Subsequent tests don't find their entry in _cp_batch_results and fall through to per-test execution via run_dpa_with_cp(...).
  5. The per-test path completes its computation, but FP8 cleanup (autocast_exit → reduce_and_update_fp8_tensors → all_reduce) needs the world group → fails with "not initialized" → 143 identical AssertionErrors.

Performance observation: the CP session ran ~9s/test (1886s / 206 actually-executed tests). That's torchrun-startup-plus-a-bit per test — batching gave ~0 speedup once the first batch crashed and the fall-through path took over.

Suggested fixes:

  • When a batch's NCCL group aborts, don't let remaining tests fall through to per-test execution with a dead group. Either reset the world process group between batches, or mark all remaining tests in the crashed batch as ERROR upfront.
  • Surface the original NCCL abort as the first failure (with the failing test variant) instead of letting it manifest as 143 cleanup-failure cascades.

Results summary:

| Hardware | Passed | Failed | Skipped |
| --- | --- | --- | --- |
| H100 (sm_90) | 41 | 0 | 10231 |
| B200 (sm_103) | 63 | 143 | 10066 |

Widen the except in _run_batch_once from AssertionError to Exception
so OS-level failures from subprocess.run (FileNotFoundError when the
worker script is missing, PermissionError, OSError when fds are
exhausted, etc.) are attributed to the batch they came from instead
of escaping the session-scoped _cp_batch_results fixture and
ERROR-ing every CP test in the run.

Addresses Greptile P1 review comment on PR 2965.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
FP8GlobalStateManager retains quantizer registrations that reference
destroyed NCCL process groups, causing cascade failures when multiple
FP8 configs run in a single torchrun batch.  Reset the singleton
between configs to prevent this.

get_available_attention_backends is stateful — calling it during the
dry-run collect phase can produce different results than during the
execute phase, causing "skip divergence" where the batch collects
configs that should have been skipped.  Cache backend availability
per test node ID so the decision is consistent across phases.

Also: pass MASTER_PORT through to torchrun so parallel pytest
invocations on different GPU sets don't collide, and add [CP-BATCH]
progress logging to the batch infrastructure.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Restore run_dpa_with_cp as self-contained: detect whether dist is
already initialized and only init/destroy the global process group
when called standalone (legacy single-config mode). In batch mode
the function reuses the caller's process group and only tears down
per-config CP comm groups.

Extract _cached_backend_check helper so the backend-availability
cache is not scattered into both test bodies. Trim verbose docstrings
and inline comments down to single-line summaries.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27 sudhakarsingh27 force-pushed the sudhakars/cp_test_batching_pr branch from 97fcd4c to e591e02 Compare May 8, 2026 21:02
Comment on lines +838 to +839
if not ok_aggregate and ok and err is None:
    err = "Failed on a non-zero rank (see subprocess stderr for traceback)"

P1 Non-rank-0 failure traceback is swallowed and the error guidance is wrong

When a non-zero rank fails inside _run_single_config, its traceback is captured in that rank's local err variable but never transmitted to rank 0. The all_reduce propagates the ok=0 flag correctly, but rank 0 only records "Failed on a non-zero rank (see subprocess stderr for traceback)". That guidance is wrong: because _run_single_config catches the exception on rank 1, rank 1 exits cleanly and torchrun exits with code 0 — there is no traceback in subprocess stderr. A developer investigating the failure would find nothing there.

This is a regression from the original non-batched flow where rank 1's uncaught exception printed directly to torchrun's stderr and was captured by run_distributed. A minimal fix is to have the failing rank(s) print their traceback to sys.stderr before returning from _run_single_config, so it appears in torchrun's captured output even when the process exits cleanly.
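That minimal fix might look like the following (illustrative; _run_single_config's real signature and body aren't shown in this thread):

import sys, traceback

def _run_single_config(kwargs):
    try:
        run_dpa_with_cp(**kwargs)
        return True, None
    except Exception:
        tb = traceback.format_exc()
        # Emit to stderr so the traceback lands in torchrun's captured output
        # even though this rank goes on to exit cleanly after the batch.
        print(tb, file=sys.stderr, flush=True)
        return False, tb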

sudhakarsingh27 added a commit to sudhakarsingh27/TransformerEngine that referenced this pull request May 9, 2026
Port CP test batching from sudhakars/cp_test_batching_pr (PR NVIDIA#2965).
Groups parametrized configs into batches of CP_TEST_BATCH_SIZE (default
16) and runs each batch in a single torchrun invocation, amortizing the
~9s NCCL init overhead across configs instead of paying it per test.

This is a temporary commit to validate batching under CI on the
flash_attn_pad_bw_seqs branch — intended to be reverted after the run.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
