[torch.compile][PyTorch] Prepare linear for torch compile #2967

pggPL wants to merge 11 commits into NVIDIA:main
Conversation
Three small refactors that make the module easier to reason about and pave the way for the dataclass / saved-tensor refactors:

- Add a ``TensorOrQuantized`` type alias (``Union[Tensor, QuantizedTensorStorage]``) used pervasively in helper signatures.
- Hoist the conditional bias argument into a local ``linear_bias_tensor`` variable instead of an inline expression at the ``linear_fn()`` call site.
- Only forward ``self.wgrad_store`` into the autograd Function when it is actually active (``delay_wgrad_compute()`` is ``True``); pass ``None`` otherwise so the autograd graph does not carry an unused Python object.

Pure rename / hoisting; no behavioural change.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Replace the loosely typed ``non_tensor_args`` tuple and the ad-hoc ``ctx.<attr>`` plumbing with two dataclasses, ``LinearFwdArgs`` and ``LinearBwdArgs``, that act as the single argument to every helper in the forward/backward pipeline. What changes:

* ``LinearFwdArgs`` carries the (positional) tensors ``weight``, ``inp`` and ``bias`` plus all quantizers, ``requires_grad`` flags, the cached ``weight_workspace`` and every former ``non_tensor_args`` knob. ``_Linear.forward`` still takes ``weight/inp/bias`` as positional Tensor inputs so autograd tracks them, then immediately re-attaches them to ``fwd_args`` so every downstream helper has a single-argument signature.
* ``LinearBwdArgs`` mirrors that on the backward side: it owns the saved tensors (``inputmat``, ``weight_fp8``, ``saved_weight``, ``bias``), the per-call quantizers, every flag previously stored directly on ``ctx``, and a ``setup_saved_tensors(saved_tensors, tensor_objects)`` helper that rehydrates the saved-tensor fields.
* ``ctx.backward_objects = bwd_args`` is now the single attribute the autograd context needs (besides ``saved_tensors``/``tensor_objects``).
* ``weight_workspace`` is no longer a positional Tensor arg of the autograd Function; it is read from ``fwd_args.weight_workspace``, and the freshly produced workspace is returned alongside ``out`` so the module can refresh its cache without autograd tracking the cache.
* ``prepare_for_saving`` now lives at the autograd boundary in ``_Linear.forward``; ``_linear_setup_ctx`` only returns the merged list of tensors that should be saved.
* ``grad_output_preprocess`` is invoked with ``bwd_args`` directly (it is duck-typed on the same attribute names) so backward never reaches into ``ctx.<attr>`` for non-tensor state.

Behaviour preserved (verified numerically against ``torch.nn.Linear`` and on FP8 + workspace-cache paths).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
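A rough sketch of the single-argument pattern this commit describes; the field names are trimmed to a handful, and ``Tensor`` here is a plain placeholder rather than ``torch.Tensor``:

```python
from dataclasses import dataclass
from typing import Any, Optional

# Placeholder: the real code uses torch.Tensor / QuantizedTensorStorage.
Tensor = Any


@dataclass
class FwdArgs:
    """Single structured argument for the forward pipeline (illustrative subset)."""

    weight: Optional[Tensor] = None  # re-attached after autograd sees it positionally
    inp: Optional[Tensor] = None
    bias: Optional[Tensor] = None
    fp8: bool = False                # former non_tensor_args knob
    tp_size: int = 1                 # former non_tensor_args knob


def forward(weight: Tensor, inp: Tensor, bias: Optional[Tensor], fwd_args: FwdArgs):
    # Tensors stay positional so autograd can track them, then are re-attached
    # so every downstream helper takes just the one dataclass.
    fwd_args.weight, fwd_args.inp, fwd_args.bias = weight, inp, bias
    return _forward_impl(fwd_args)


def _forward_impl(args: FwdArgs):
    # Downstream helper reads everything from the single argument.
    return (args.inp, args.weight, args.bias, args.fp8, args.tp_size)
```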
When ``saved_inputmat is inp``, ``wt_save is weight``, or ``bias`` is the exact bias passed in, there is no point asking ``prepare_for_saving`` to serialize the same Python object twice. Make ``_linear_forward_impl`` emit ``None`` in those slots (and a parallel ``saved_tensor_aliases`` tuple in ``ctx_attrs`` describing which slot points where), and have ``_linear_setup_ctx`` rebuild the tuple with the original references before handing it to ``prepare_for_saving``.

Saves a Python ref per alias in eager and, more importantly, keeps the forward helper from "returning" a tensor that aliases its own inputs, a pattern ``torch.compile`` would otherwise need to reason about when the helper is wrapped in an opaque op.

Numerically equivalent (validated against ``torch.nn.Linear`` and on a multi-iteration FP8 path with workspace caching).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
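A minimal sketch of that alias-elision round trip; ``elide_aliases`` / ``rebuild_aliases`` are hypothetical helper names for illustration, not the actual TE functions:

```python
def elide_aliases(saved, inputs):
    """Replace saved entries that alias (are identical to) a forward input
    with None, recording which input each slot points at."""
    aliases, elided = [], []
    for t in saved:
        # Identity check, mirroring the `saved_inputmat is inp` style tests.
        name = next((n for n, i in inputs.items() if t is i), None)
        aliases.append(name)
        elided.append(None if name is not None else t)
    return elided, tuple(aliases)


def rebuild_aliases(elided, aliases, inputs):
    """Inverse step: re-insert the original references before serialization."""
    return [inputs[name] if name is not None else t
            for t, name in zip(elided, aliases)]
```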
for more information, see https://pre-commit.ci
Follow-up cleanups on top of the dataclass refactor:
* Sort ``LinearFwdArgs`` / ``LinearBwdArgs`` fields into labelled groups
(tensors, requires_grad flags, quantizers, dtype/numerical config,
parallelism, userbuffers, FSDP, wgrad scheduling, misc) and mirror that
ordering in their construction sites.
* Add ``slots=True`` to both dataclasses so typos in
``fwd_args.X`` / ``bwd_args.X`` raise ``AttributeError`` immediately
instead of silently creating a new attribute.
* Inline single-use ``args.X`` aliases in ``_linear_forward_impl``
(``weight_workspace``, ``fp8_calibration``, ``tp_size``,
``tensor_parallel``, ``cache_weight``, ``skip_fp8_weight_update``,
``custom``, ``backward_input_needs_gather``) so the prelude only keeps
aliases that are actually reused.
* Shrink ``ctx_attrs`` to ``{fsdp_shapes, saved_tensor_aliases}``:
``weight_quantizer`` is re-derived in ``_linear_setup_ctx`` from
``fwd_args.weight`` (matching the resolution done in forward),
``is_fsdp2`` already lives on ``fwd_args``, and ``owns_input`` is
equivalent to ``saved_tensor_aliases[0] != "inp"``.
* Replace ``setup_saved_tensors(saved_tensors, tensor_objects)`` with
``setup_saved_tensors(ctx)`` backed by ``restore_from_func_ctx``,
matching ``layernorm_mlp`` / ``layernorm_linear`` /
``grouped_linear`` and dropping the manual
``ctx.tensor_objects = None`` cleanup.
Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
After packing the Linear backward state into ``LinearBwdArgs``, the attributes the test was reading (``backward_override``, ``fp8``, ``grad_output_quantizer``, ``reduce_and_update_bwd_fp8_tensors``) no longer live directly on ``grad_fn``. Read them from ``grad_fn.backward_objects`` when present, falling back to ``grad_fn`` for the linear-like modules that have not been refactored yet (``layernorm_linear``, ``ops_linear``).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Restore the one-line class docstrings dropped during the field reorganization so pylint stops warning about C0115.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Greptile Summary

This PR refactors
Confidence Score: 4/5

Safe to merge for the common single-backward case; callers using `retain_graph=True` will hit a new `AttributeError` crash on the second backward call that was not present before this refactor. The mechanical ctx-to-dataclass translation is faithful, and the return-tuple arity change (11 to 4 gradients) correctly matches the new 4-argument `_Linear.forward`. The one concrete regression is `ctx.backward_objects = None` after backward: the comment acknowledges `retain_graph`, but nulling the outer pointer rather than individual fields means any second backward call immediately hits `None` and crashes, a path that worked in the old code.

Important file: `transformer_engine/pytorch/module/linear.py`, specifically the `ctx.backward_objects = None` teardown in `_Linear.backward` and its interaction with `retain_graph=True`.
Sequence Diagram

```mermaid
sequenceDiagram
    participant LF as Linear.forward
    participant App as _Linear.apply
    participant Fwd as _linear_forward_impl
    participant SetCtx as _linear_setup_ctx
    participant Bwd as _linear_backward
    LF->>LF: Build LinearFwdArgs (all config + tensors)
    LF->>App: apply(weight, inp, bias, fwd_args)
    App->>Fwd: _linear_forward_impl(fwd_args)
    Fwd-->>App: out, new_ws, tensors_to_save_from_forward, None, ctx_attrs
    App->>App: Create LinearBwdArgs()
    App->>SetCtx: _linear_setup_ctx(bwd_args, fwd_args, out, ctx_attrs, tensors)
    SetCtx->>SetCtx: Populate bwd_args fields from fwd_args
    SetCtx->>SetCtx: Reconstruct aliased tensors (inp/weight/bias)
    SetCtx-->>App: tensors_to_save_from_setup
    App->>App: prepare_for_saving + ctx.save_for_backward
    App->>App: ctx.backward_objects = bwd_args
    App-->>LF: out, new_weight_workspace
    Note over App,Bwd: Later - gradient computation
    App->>App: bwd_args = ctx.backward_objects
    App->>App: bwd_args.setup_saved_tensors(ctx)
    App->>Bwd: _linear_backward(bwd_args)
    Bwd-->>App: (wgrad, dgrad, grad_bias)
    App->>App: ctx.backward_objects = None (breaks retain_graph)
    App-->>LF: (wgrad, dgrad, grad_bias, None)
```
Saved tensors, quantizers, weakrefs, and ``main_grad`` closures referenced from ``LinearBwdArgs`` survived until ctx GC, extending peak GPU memory under ``retain_graph=True``. Null out ``ctx.backward_objects`` right after ``_linear_backward`` so they are released as soon as backward returns.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
/te-ci pytorch L1
```python
# Drop all references held by bwd_args (saved tensors, quantizers, weakrefs,
# main_grad closure) so they don't outlive backward via ctx under retain_graph.
ctx.backward_objects = None
del bwd_args
```
If we destroy the cached state, we should also mark `backward` with `torch.autograd.function.once_differentiable`.
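For reference, a toy Function showing where that decorator goes; `Square` is illustrative. `once_differentiable` makes a second differentiation through this backward raise an error instead of silently producing wrong higher-order gradients:

```python
import torch
from torch.autograd.function import once_differentiable


class Square(torch.autograd.Function):
    """Toy Function whose backward is marked once_differentiable."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @once_differentiable  # double backward through this op now errors out
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out
```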
```python
@dataclass(slots=True)
class LinearFwdArgs:
```
This design is nicer than exposing a long list of positional args or passing in a tuple of args, but it is less nice than exposing kwargs. However, it has some advantages for this case:

- `torch.compile` infrastructure will only need to handle one non-tensor arg to the autograd function.
- Autograd functions do some processing on each arg, so minimizing the number of non-tensor args reduces CPU overhead.
- `_Linear.forward` calls `_linear_forward_impl` and it doesn't need to be aware of the impl specifics.
```python
@dataclass(slots=True)
class LinearBwdArgs:
```
In typical usage, LinearBwdArgs is somewhat redundant with the autograd context class. But we need it so that the forward-context-saving and backward are less entangled with autograd, which will allow code reuse when we implement an alternate torch.compile code path.
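A stripped-down sketch of that forward/backward disentangling, with illustrative field names and a fake ctx standing in for the autograd context:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class BwdArgs:
    """Sketch of a LinearBwdArgs-style holder: tensor fields start empty and
    are rehydrated from the autograd ctx right before backward runs."""

    inputmat: Optional[Any] = None
    saved_weight: Optional[Any] = None
    bias: Optional[Any] = None

    def setup_saved_tensors(self, ctx) -> None:
        # The real code is backed by restore_from_func_ctx; here we simply
        # read a saved-tensor tuple in a fixed order.
        self.inputmat, self.saved_weight, self.bias = ctx.saved_tensors


class FakeCtx:
    """Stand-in for the autograd context: only exposes saved_tensors."""

    saved_tensors = ("x", "w", "b")


# Backward only ever touches bwd_args, not ctx, after rehydration.
bwd = BwdArgs()
bwd.setup_saved_tensors(FakeCtx())
```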
Description

Refactor of `transformer_engine/pytorch/module/linear.py` to lift the `Linear` module into a shape that can be wrapped in a `torch.library.custom_op` (and matching backward op) in a follow-up PR.

Type of change

Changes
- `LinearFwdArgs` / `LinearBwdArgs` dataclasses. `_linear_forward_impl`, `_linear_setup_ctx`, `_linear_backward` and the `_Linear.{forward,backward}` autograd methods all take a single structured argument instead of 25+ positional ones. A custom op requires a fully-declared signature on both sides; the previous pattern of writing arbitrary `ctx.something = ...` attributes scattered throughout forward made it impossible to tell from the call site what state backward actually consumes. The dataclasses make the read/write contract explicit and grep-able. Concretely, things that used to be re-queried from tensor objects (`input.requires_grad`, `weight.requires_grad`, `bias.requires_grad`) are now captured up front as `input_requires_grad` / `weight_requires_grad` / `bias_requires_grad` and consumed as `requires_dgrad` / `requires_wgrad` in backward; backward no longer has to assume the Python tensor objects survive the op boundary.
- Moved `prepare_for_saving` / `ctx.save_for_backward` to the autograd boundary. `_linear_forward_impl` returns the raw tensors it produced (along with `tensors_to_save_from_forward` aliases) and `_linear_setup_ctx` returns the raw merged tensor list it wants saved; `_Linear.forward` is the only place that actually calls `prepare_for_saving(*tensors_to_save_from_setup)` and `ctx.save_for_backward(...)`. This shape fell out of the compile-path experiments: under `torch.library.register_autograd`, the `setup_context` callback is the only legal place to call `save_for_backward`, and the helper has to hand back tensors rather than mutate the autograd ctx itself. The same contract is now used in eager mode so a single helper serves both modes.
- Saved tensors that would alias `inp` / `weight` / `bias` are emitted as `None` and reconstructed in `_linear_setup_ctx` from the original refs. An opaque custom op cannot return aliases of its inputs (the tracer has no way to reason about the aliasing).
- Shrank the `ctx_attrs` blob plumbed from forward to backward setup. Anything that can be re-derived from `LinearFwdArgs` (`weight_quantizer`, `is_fsdp2`, `owns_input`) is recomputed in `_linear_setup_ctx`. The compile path needs the fake forward impl to return a structurally identical `ctx_attrs`, so a smaller surface is a smaller cross-impl contract.
Checklist: