[PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init #2965

sudhakarsingh27 wants to merge 8 commits into NVIDIA:main from sudhakars/cp_test_batching_pr
Conversation
Greptile Summary

This PR replaces per-test torchrun launches with a session-scoped batched dispatch, so the NCCL init/destroy cost is paid once per batch instead of once per test.

Confidence Score: 4/5. Safe to merge for the performance win; several robustness gaps in failure reporting and communicator cleanup were raised in earlier reviews and are not yet addressed. The batch dispatch logic and result flushing are sound, and the 2x speedup is well-validated. The changes that remain open from prior reviews (NCCL communicator leaks on mid-config exceptions, non-rank-0 traceback loss, and the misleading error guidance) are still present in the code and can make diagnosing failures harder in practice. In tests/pytorch/attention/run_attention_with_cp.py, the per-config communicator teardown path and the non-rank-0 error attribution logic warrant a second look before the first production failure surfaces.
Reviews (12). Last reviewed commit: "Merge branch 'main' into sudhakars/cp_te…"
Force-pushed from fa189b0 to 0e9fc1f.
Force-pushed from 7802ec5 to c80df5d.
```python
try:
    argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path)
    launch_err = None
    try:
        run_distributed(argv)
    except AssertionError as exc:
        launch_err = str(exc)
```
Only AssertionError is caught from run_distributed

subprocess.run inside run_distributed can raise FileNotFoundError, PermissionError, or OSError for OS-level failures (missing executable, exhausted file descriptors, etc.). These propagate uncaught through _run_batch_once → _run_one_batch → _cp_batch_results. Because the fixture is session-scoped, one such exception causes every test that depends on _cp_batch_results to surface as a fixture ERROR rather than an individual test failure. In the original code, the same OS error would fail only the one test that triggered it. Widening the except to Exception (which still covers AssertionError) before reading the results file would preserve the per-batch isolation benefit.
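A minimal sketch of the suggested widening, reusing the names from the diff hunk above (get_bash_arguments, run_distributed, num_gpus, batch_path); the surrounding results-file handling is assumed, not shown:

```python
argv = get_bash_arguments(num_gpus_per_node=num_gpus, batch_config_json=batch_path)
launch_err = None
try:
    run_distributed(argv)
except Exception as exc:  # was: except AssertionError
    # OS-level failures (FileNotFoundError, PermissionError, OSError, ...) are now
    # attributed to this batch instead of ERROR-ing the session-scoped fixture.
    launch_err = f"{type(exc).__name__}: {exc}"
```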
[PyTorch] Batch CP attention tests in single torchrun to amortize NCCL init
Each parametrized CP test currently spawns its own torchrun process and
pays 5-15s of NCCL init/destroy. With ~650-800 collected tests this
adds up to 1.5-3 hours of pure setup overhead.
This change introduces a session-scoped fixture that:
1. Calls per-test ``_prepare_*`` helpers to get either a skip reason or
a kwargs dict for the worker.
2. Groups runnable configs by ``num_gpus`` and chunks them into batches
of CP_TEST_BATCH_SIZE (default 16).
3. Launches one torchrun per chunk; the worker initialises NCCL once
and runs all configs in the chunk inside the same world.
Per-config results are flushed to JSON after every config so a crash
mid-batch still leaves earlier results intact. Set CP_TEST_BATCH_SIZE=1
to bisect a failing batch.
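A minimal sketch of the grouping/chunking step described in (2)-(3); the helper name and the shape of the config dict are illustrative, only CP_TEST_BATCH_SIZE comes from this change:

```python
import os
from collections import defaultdict

def chunk_runnable_configs(runnable):
    """runnable: {nodeid: worker_kwargs}; returns [(num_gpus, [(nodeid, kwargs), ...]), ...]."""
    batch_size = int(os.getenv("CP_TEST_BATCH_SIZE", "16"))
    by_gpus = defaultdict(list)
    for nodeid, kwargs in runnable.items():
        by_gpus[kwargs["num_gpus_per_node"]].append((nodeid, kwargs))
    batches = []
    for num_gpus, configs in sorted(by_gpus.items()):
        for i in range(0, len(configs), batch_size):
            # one torchrun launch per chunk; NCCL is initialised once per chunk
            batches.append((num_gpus, configs[i : i + batch_size]))
    return batches
```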
Also includes a small bugfix in dot_product_attention/utils.py: the
deterministic-FA3 disable condition was firing for any head_dim_qk > 128
(including inference); restrict it to is_training and large head dims.
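A hedged sketch of that corrected gating; the variable names (is_training, head_dim_qk) follow the description above and the exact expression in utils.py may differ:

```python
def should_disable_deterministic_fa3(is_training: bool, head_dim_qk: int) -> bool:
    # Previously fired for any head_dim_qk > 128 (including inference);
    # now restricted to training with large head dims.
    return is_training and head_dim_qk > 128
```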
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Force-pushed from 1db76b7 to 6355f62.
[pre-commit.ci] auto fixes from pre-commit hooks; for more information, see https://pre-commit.ci
```python
dist.destroy_process_group(cp_comm_group)
if cp_comm_type == "a2a+p2p":
    for sg in cp_comm_sub_groups:
        dist.destroy_process_group(sg)
```
NCCL communicator leak on exception mid-function
cp_comm_group is created at line 238 (and up to 4 sub-groups for a2a+p2p at lines 248-250) but the destroy calls are at the very bottom of the function with no try/finally. Any exception that fires in between — a CUDA OOM, a comparison mismatch, a BaseException from cuDNN — causes _run_single_config to catch it and return (False, traceback), while the communicators are never cleaned up.
In batch mode the problem compounds: with 16 configs per torchrun and any flaky configs, leaked communicators accumulate across the whole batch. NCCL's internal communicator table has a fixed limit (typically 128), so a few hundred batched configs with occasional failures can exhaust it and corrupt subsequent configs with opaque "NCCL error: invalid usage" rather than surfacing the original failure. Wrapping the body after group creation in a try/finally guarantees cleanup.
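A minimal sketch of the suggested try/finally shape; the group creation and per-config body are simplified placeholders, only the teardown calls mirror the diff hunk above:

```python
import torch.distributed as dist

def run_config_with_cleanup(run_config, ranks, cp_comm_type, sub_group_ranks=()):
    # Illustrative shape only: the real creation/usage lives in _run_single_config.
    cp_comm_group = dist.new_group(ranks)
    cp_comm_sub_groups = (
        [dist.new_group(r) for r in sub_group_ranks] if cp_comm_type == "a2a+p2p" else []
    )
    try:
        return run_config(cp_comm_group, cp_comm_sub_groups)
    finally:
        # Runs on success, failure, and BaseException alike, so leaked communicators
        # cannot accumulate across a 16-config batch and exhaust NCCL's table.
        dist.destroy_process_group(cp_comm_group)
        for sg in cp_comm_sub_groups:
            dist.destroy_process_group(sg)
```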
Tested this PR's batched CP changes on B200 (sm_103) and H100 (sm_90). The H100 run passed because most CP variants gate on sm_103 and skip on sm_90; only 41 tests actually executed. The B200 run surfaced 143 failures that all share a single root cause; they are not independent bugs. Failure pattern: every one of the 143 failures has the same traceback, referencing the same cascade mechanism.
Performance observation: the CP session ran ~9 s/test (1886 s / 206 actually-executed tests). That's torchrun startup plus a little per test; batching gave ~0 speedup once the first batch crashed and the fall-through path took over.
Widen the except in _run_batch_once from AssertionError to Exception so OS-level failures from subprocess.run (FileNotFoundError when the worker script is missing, PermissionError, OSError when fds are exhausted, etc.) are attributed to the batch they came from instead of escaping the session-scoped _cp_batch_results fixture and ERROR-ing every CP test in the run. Addresses Greptile P1 review comment on PR 2965. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
FP8GlobalStateManager retains quantizer registrations that reference destroyed NCCL process groups, causing cascade failures when multiple FP8 configs run in a single torchrun batch. Reset the singleton between configs to prevent this. get_available_attention_backends is stateful — calling it during the dry-run collect phase can produce different results than during the execute phase, causing "skip divergence" where the batch collects configs that should have been skipped. Cache backend availability per test node ID so the decision is consistent across phases. Also: pass MASTER_PORT through to torchrun so parallel pytest invocations on different GPU sets don't collide, and add [CP-BATCH] progress logging to the batch infrastructure. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Restore run_dpa_with_cp as self-contained: detect whether dist is already initialized and only init/destroy the global process group when called standalone (legacy single-config mode). In batch mode the function reuses the caller's process group and only tears down per-config CP comm groups. Extract _cached_backend_check helper so the backend-availability cache is not scattered into both test bodies. Trim verbose docstrings and inline comments down to single-line summaries. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
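A minimal sketch of the per-nodeid caching idea behind _cached_backend_check; the cache name and wrapper shape follow the description above, the exact arguments forwarded to get_available_attention_backends are left generic:

```python
_BACKEND_CACHE = {}

def _cached_backend_check(nodeid, *args, **kwargs):
    # get_available_attention_backends is stateful; compute it once per test nodeid so
    # the dry-run (collect) phase and the execute phase see the same skip decision.
    if nodeid not in _BACKEND_CACHE:
        _BACKEND_CACHE[nodeid] = get_available_attention_backends(*args, **kwargs)
    return _BACKEND_CACHE[nodeid]
```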
Force-pushed from 97fcd4c to e591e02.
[pre-commit.ci] auto fixes from pre-commit hooks; for more information, see https://pre-commit.ci
Merge branch 'main' into sudhakars/cp_test_batching_pr
```python
if not ok_aggregate and ok and err is None:
    err = "Failed on a non-zero rank (see subprocess stderr for traceback)"
```
Non-rank-0 failure traceback is swallowed and the error guidance is wrong
When a non-zero rank fails inside _run_single_config, its traceback is captured in that rank's local err variable but never transmitted to rank 0. The all_reduce propagates the ok=0 flag correctly, but rank 0 only records "Failed on a non-zero rank (see subprocess stderr for traceback)". That guidance is wrong: because _run_single_config catches the exception on rank 1, rank 1 exits cleanly and torchrun exits with code 0 — there is no traceback in subprocess stderr. A developer investigating the failure would find nothing there.
This is a regression from the original non-batched flow where rank 1's uncaught exception printed directly to torchrun's stderr and was captured by run_distributed. A minimal fix is to have the failing rank(s) print their traceback to sys.stderr before returning from _run_single_config, so it appears in torchrun's captured output even when the process exits cleanly.
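A minimal sketch of that fix inside the per-config exception handler; the wrapper shape is assumed from the description above, not copied from the file:

```python
import sys
import traceback

def _run_single_config(run_config, rank):
    # Illustrative wrapper: the real function runs one CP config and aggregates ok/err.
    try:
        run_config()
        return True, None
    except Exception:
        err = traceback.format_exc()
        # Print on the failing rank so the traceback lands in torchrun's captured
        # stderr even though the process exits cleanly (exit code 0).
        print(f"[rank {rank}] config failed:\n{err}", file=sys.stderr, flush=True)
        return False, err
```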
Port CP test batching from sudhakars/cp_test_batching_pr (PR NVIDIA#2965). Groups parametrized configs into batches of CP_TEST_BATCH_SIZE (default 16) and runs each batch in a single torchrun invocation, amortizing the ~9s NCCL init overhead across configs instead of paying it per test. This is a temporary commit to validate batching under CI on the flash_attn_pad_bw_seqs branch — intended to be reverted after the run. Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Design
Problem
8×H100, test_essential=True, 38 runnable CP attention configs: each config currently launches its own torchrun and pays the full NCCL init/destroy cost. We need that overhead amortised, without changing how tests are written or how skips report.
Approach
A session-scoped fixture (_cp_batch_results) does two passes:

- Collect (dry-run): for each collected item that requests _cp_batch_results, call its test function directly with a stubbed request. The body executes its inline pytest.skip(...) checks normally; if any fires, the item is dropped from the batch. Otherwise the body's final call to _run_or_fetch(...) records its kwargs in a module-level dict instead of launching a subprocess.
- Dispatch and execute: group runnable configs by num_gpus_per_node, chunk into batches of CP_TEST_BATCH_SIZE (default 16), and launch one torchrun per chunk. The worker (run_attention_with_cp.py) inits NCCL once, loops over configs, and atomically flushes per-config results to <batch>.results.json. When pytest later runs each test for real, the body re-evaluates skips and _run_or_fetch looks up the recorded result.

How dry-run works
_dry_run_item calls the underlying function with the same parametrize values pytest would have passed (a sketch follows below). This bypasses pytest's runner entirely: no fixture setup hooks, no plugin reporters, no captured-stdout machinery.
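A minimal sketch of that direct call, assuming the pytest item attributes named in the edge-cases section (item.function, item.callspec.params) and the stubs described under "Stubs and skip handling"; the dummy request here is reduced to the one attribute the helper reads:

```python
import types
import pytest

def _dry_run_item(item):
    """Call the test function outside pytest's runner; return False if it would skip."""
    # Stub request: only request.node.nodeid is consumed by _run_or_fetch.
    dummy_request = types.SimpleNamespace(node=types.SimpleNamespace(nodeid=item.nodeid))
    kwargs = dict(item.callspec.params)       # the parametrize values pytest would pass
    kwargs["request"] = dummy_request
    kwargs["_cp_batch_results"] = {}          # empty-dict stub; ignored in collect mode
    try:
        item.function(**kwargs)               # inline pytest.skip(...) may fire here
    except pytest.skip.Exception:
        return False                          # drop the item from the batch
    return True
```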
_run_or_fetch checks a module-level _COLLECT_MODE flag (sketched below). In collect mode it's a recorder, in execute mode it's a result-fetcher. The test body doesn't know which mode it's in; it just calls one helper at the end.
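A minimal sketch of the two modes, assuming a module-level _COLLECT_MODE flag and a _COLLECTED dict as the recorder; names other than _run_or_fetch and the nodeid lookup are illustrative:

```python
import pytest

_COLLECT_MODE = False   # flipped to True by the fixture during the dry-run pass
_COLLECTED = {}         # nodeid -> kwargs for the torchrun worker

def _run_or_fetch(request, batch_results, **worker_kwargs):
    nodeid = request.node.nodeid
    if _COLLECT_MODE:
        # Recorder: remember what this test wants to run; launch nothing.
        _COLLECTED[nodeid] = worker_kwargs
        return
    # Fetcher: the batch already ran; report the stored per-config outcome.
    ok, err = batch_results[nodeid]
    if not ok:
        pytest.fail(err)
```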
Stubs and skip handling
- request → _DummyRequest(nodeid), exposing only request.node.nodeid. _run_or_fetch only reads nodeid; the test body never touches request.
- _cp_batch_results → {} (empty dict). _run_or_fetch returns early in collect mode and never inspects batch_results.

Inline pytest.skip("reason") raises pytest.skip.Exception. The dry-run loop catches it per item, drops the item from the batch, and moves on. In execute mode the same line raises again; pytest reports SKIPPED with the same reason.

@pytest.mark.skip and @pytest.mark.skipif(<bool_condition>) markers don't fire when calling item.function(...) directly. _item_static_skip(item) walks item.iter_markers("skip"|"skipif") and reads marker.args[0] (the condition) before the dry-run, dropping items the markers would otherwise skip.

Cost of running the body twice

For each item the body runs once during dry-run and once during execute. Skip checks are pure Python; the only non-trivial work is get_available_attention_backends, cached per nodeid via _BACKEND_CACHE so the second call is a dict hit. Measured on full test_essential=True (10272 collected items, 38 runnable): 530 cache lookups, 0.03 s total.

End-to-end pytest overhead (dry-run + collection, with torchrun stubbed): ~14 s wall, of which ~6.6 s is module-import startup, ~3.2 s is pytest per-item setup, and 0.2 s is the batching infra itself. Negligible vs the GPU work it dispatches.

Performance
8×H100, test_essential=True (38 runnable configs: 34 × 2-GPU + 4 × 4-GPU). In unbatched mode each config is its own torchrun. In batched mode, configs sharing the same num_gpus_per_node are grouped into one torchrun of up to CP_TEST_BATCH_SIZE configs. (Wall-time table for B=16 / B=32 / B=50 not reproduced here.)

Where the 2× comes from

Fewer torchrun spawns: each saved spawn cuts ~12 s of startup (Python imports + NCCL global init/teardown), measured directly from the wall-time delta B=50 → B=16 (37 s saved across 2 fewer spawns). The spawn savings dominate.
Picking CP_TEST_BATCH_SIZE

For these 38 configs (34 × 2-GPU + 4 × 4-GPU):

- B=16 → B=32: −27 s (one fewer 2-GPU spawn).
- B=32 → B=50: −11 s (one more spawn saved, but the merged batch is now large enough that bookkeeping eats into the savings).

16 and 32 are ballparks for this matrix: once B exceeds the largest GPU-group size (here, 34), all configs in that group already share a torchrun and further increases do nothing. With a larger config matrix (e.g. full test_essential=False ≈ 348 runnable), the same logic implies B should scale up too: pick it so the largest GPU-group needs only a small number of torchrun spawns, but not so large that a single batch becomes long enough that a worker crash loses too much progress.

Knobs
- CP_TEST_BATCH_SIZE=N: max configs per torchrun. Default 16. Set 1 to bisect.
- CP_TEST_BATCH_RETRY=0: turns off the worker's per-config retry.

Adding a batched test

1. Keep the @pytest.mark.parametrize stack + inline pytest.skip(...) checks.
2. Add request, _cp_batch_results to the function signature.
3. Replace run_distributed(get_bash_arguments(...)) with _run_or_fetch(request, _cp_batch_results, num_gpus_per_node=N, ...) (the kwargs become the worker's run_dpa_with_cp(**kwargs) arguments).

That's the entire wiring; a before/after sketch follows below.
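A hedged before/after sketch of step 3; the parametrize stack, skip check, and kwargs are placeholders standing in for a real CP test, only the helper names come from this PR:

```python
import pytest
import torch

# Before: every parametrized config launches its own torchrun.
@pytest.mark.parametrize("model", ["cp_1_0"])
def test_cp_with_fused_attention(model):
    if not torch.cuda.is_available():
        pytest.skip("CUDA required")
    run_distributed(get_bash_arguments(num_gpus_per_node=2, dtype="bf16", model=model))

# After: the same body records its kwargs during the dry-run pass and fetches the
# batched result during the execute pass.
@pytest.mark.parametrize("model", ["cp_1_0"])
def test_cp_with_fused_attention_batched(model, request, _cp_batch_results):
    if not torch.cuda.is_available():
        pytest.skip("CUDA required")
    _run_or_fetch(request, _cp_batch_results, num_gpus_per_node=2, dtype="bf16", model=model)
```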
Failure semantics
- Inline pytest.skip(...) fires → the test reports SKIPPED as usual (the body never reaches _run_or_fetch).
- @pytest.mark.skip(if) marker fires → the item is dropped before dispatch (never launched in a torchrun).
- A config fails on any rank → the failure is propagated to every rank via dist.all_reduce(ok, op=MIN).

Mitigations for shared-process state
Configs in a batch share one Python process and one NCCL world, so anything that needs a clean per-test starting point is reset explicitly:
- Per-config CP comm groups are destroyed after each config (cp_comm_group, a2a+p2p sub-groups).
- Transient env vars are restored via _TRANSIENT_ENV_KEYS between configs (NVTE_FLASH_ATTN, NVTE_FUSED_ATTN, NVTE_FP8_DPA_BWD, NVTE_DPA_FP8CS_O_in_F16, NVTE_ALLOW_NONDETERMINISTIC_ALGO).
- torch.cuda.empty_cache() and dist.barrier() between configs.
- RNG reseeded (1234) at the start of each config.
- copy.deepcopy(model_configs_*[model]) in the worker (the THD path mutates attn_mask_type).
- Atomic result flush (tmp + os.replace): a partial JSON is never visible to the reader (sketch after this list).
- dist.all_reduce(ok, op=MIN) after each config so any rank's failure flips ok to False.
- Optional per-config retry, disabled with CP_TEST_BATCH_RETRY=0.
- Worker arg parsing uses arg.split("=", 1) so kwarg values containing = (paths) survive.
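A minimal sketch of the atomic per-config flush; the file layout ({nodeid: [ok, err]}) is illustrative, the tmp + os.replace pattern is the point:

```python
import json
import os

def flush_results(results_path, results):
    """Write per-config results atomically: a reader never sees a partial JSON."""
    tmp_path = results_path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump(results, f)
    os.replace(tmp_path, results_path)  # atomic rename on POSIX
```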
Edge cases

- Limited request API surface during dry-run. Only request.node.nodeid is provided. A future test that uses request.config.getoption(...) or request.getfixturevalue(...) would AttributeError during dry-run. The fixture catches BaseException, so the same error fires in execute mode where pytest's real request is available.
- @pytest.mark.skipif(condition_evaluated_at_runtime). A skipif whose condition becomes True only at execute time would not be detected by _item_static_skip. The condition still fires correctly in execute mode; we'd just have wasted one batch slot for it.
- get_available_attention_backends non-determinism. If this returns different values between dry-run and execute (driver state changes), a config queued by collect could skip in execute. Harmless: _run_or_fetch is never reached, and the unused batch result is garbage-collected.
- Reliance on pytest internals: item.function, item.callspec.params, and pytest.skip.Exception. Stable in pytest 7+/8+. If they shift, _dry_run_item is a 3-line shim to update.

Validation
8×H100, test_essential=True: 38 passed / 10234 skipped / 0 unrelated failures.

Stress (no regressions): single nodeid, -k <no-match>, --collect-only, small subset, CP_TEST_BATCH_SIZE=1 all behave normally.

Files
- tests/pytorch/attention/test_attention_with_cp.py: collect/dispatch/fetch infra, dry-run helpers, test bodies updated minimally.
- tests/pytorch/attention/run_attention_with_cp.py: _init_distributed, main() batch mode, atomic per-config flush, cross-rank aggregation, per-config group teardown, copy.deepcopy of model configs, transient env reset, split("=", 1).