
[torch.compile][PyTorch] Prepare linear for torch compile #2967

Open

pggPL wants to merge 11 commits into NVIDIA:main from
pggPL:prepare_linear_for_torch_compile

Conversation

Collaborator

@pggPL pggPL commented May 7, 2026

Description

Refactor of transformer_engine/pytorch/module/linear.py to lift the
Linear module into a shape that can be wrapped in a
torch.library.custom_op (and matching backward op) in a follow-up PR.

Type of change

  • Code refactoring

Changes

  • Pack forward/backward state into LinearFwdArgs / LinearBwdArgs
    dataclasses.
    _linear_forward_impl, _linear_setup_ctx,
    _linear_backward and the _Linear.{forward,backward} autograd
    methods all take a single structured argument instead of 25+
    positional ones. A custom op requires a fully-declared signature on
    both sides; the previous pattern of writing arbitrary
    ctx.something = ... attributes scattered throughout forward made
    it impossible to tell from the call site what state backward
    actually consumes. The dataclasses make the read/write contract
    explicit and grep-able. Concretely, things that used to be re-queried
    from tensor objects (input.requires_grad, weight.requires_grad,
    bias.requires_grad) are now captured up front as
    input_requires_grad / weight_requires_grad /
    bias_requires_grad and consumed as requires_dgrad /
    requires_wgrad in backward — backward no longer has to assume the
    Python tensor objects survive the op boundary.
  • Move prepare_for_saving / ctx.save_for_backward to the autograd
    boundary.
    _linear_forward_impl returns the raw tensors it
    produced (along with tensors_to_save_from_forward aliases) and
    _linear_setup_ctx returns the raw merged tensor list it wants
    saved; _Linear.forward is the only place that actually calls
    prepare_for_saving(*tensors_to_save_from_setup) and
    ctx.save_for_backward(...). This shape fell out of the compile-path
    experiments: under torch.library.register_autograd, the
    setup_context callback is the only legal place to call
    save_for_backward, and the helper has to hand back tensors rather
    than mutate the autograd ctx itself. The same contract is now used in
    eager mode, so a single helper serves both paths.
  • Deduplicate saved tensors that alias forward inputs. Save-slots
    that would alias inp / weight / bias are emitted as None and
    reconstructed in _linear_setup_ctx from the original refs. An
    opaque custom op cannot return aliases of its inputs (the tracer
    has no way to reason about the aliasing).
  • Minimize the ctx_attrs blob plumbed from forward to backward
    setup.
    Anything that can be re-derived from LinearFwdArgs
    (weight_quantizer, is_fsdp2, owns_input) is recomputed in
    _linear_setup_ctx. The compile path needs the fake forward impl
    to return a structurally identical ctx_attrs, so a smaller surface
    is a smaller cross-impl contract.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

pggPL and others added 10 commits May 8, 2026 00:17
Three small refactors that make the module easier to reason about
and pave the way for the dataclass / saved-tensor refactors:

- Add a TensorOrQuantized type alias (Union[Tensor, QuantizedTensorStorage])
  used pervasively in helper signatures.
- Hoist the conditional bias argument into a local linear_bias_tensor
  variable instead of an inline expression at the linear_fn() call site.
- Only forward self.wgrad_store into the autograd Function when it is
  actually active (delay_wgrad_compute() is True); pass None otherwise so
  the autograd graph does not carry an unused Python object.

Pure rename / hoisting; no behavioural change.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
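The alias and the wgrad_store gating from this commit can be sketched with stub classes (Tensor, QuantizedTensorStorage, and WgradStore are stand-ins here; only the shapes matter):

```python
from typing import Optional, Union

# Stand-in stubs for the real classes.
class Tensor: ...
class QuantizedTensorStorage: ...

# The alias used pervasively in helper signatures.
TensorOrQuantized = Union[Tensor, QuantizedTensorStorage]

class WgradStore:
    def __init__(self, delay: bool):
        self._delay = delay
    def delay_wgrad_compute(self) -> bool:
        return self._delay

def wgrad_store_for_autograd(wgrad_store: Optional[WgradStore]):
    # Only forward the store into the autograd Function when it is
    # actually active, so the autograd graph does not carry an unused
    # Python object.
    if wgrad_store is not None and wgrad_store.delay_wgrad_compute():
        return wgrad_store
    return None
```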
Replace the loosely typed ``non_tensor_args`` tuple and the ad-hoc
``ctx.<attr>`` plumbing with two dataclasses, ``LinearFwdArgs`` and
``LinearBwdArgs``, that act as the single argument to every helper
in the forward/backward pipeline.

What changes:

* ``LinearFwdArgs`` carries the (positional) tensors ``weight``, ``inp``
  and ``bias`` plus all quantizers, ``requires_grad`` flags, the cached
  ``weight_workspace`` and every former ``non_tensor_args`` knob.
  ``_Linear.forward`` still takes ``weight/inp/bias`` as positional
  Tensor inputs so autograd tracks them, then immediately re-attaches
  them to ``fwd_args`` so every downstream helper has a single-argument
  signature.
* ``LinearBwdArgs`` mirrors that on the backward side: it owns the
  saved tensors (``inputmat``, ``weight_fp8``, ``saved_weight``,
  ``bias``), the per-call quantizers, every flag previously stored
  directly on ``ctx`` and a ``setup_saved_tensors(saved_tensors,
  tensor_objects)`` helper that rehydrates the saved-tensor fields.
* ``ctx.backward_objects = bwd_args`` is now the single attribute the
  autograd context needs (besides ``saved_tensors``/``tensor_objects``).
* ``weight_workspace`` is no longer a positional Tensor arg of the
  autograd Function; it is read from ``fwd_args.weight_workspace`` and
  the freshly produced workspace is returned alongside ``out`` so the
  module can refresh its cache without autograd tracking the cache.
* ``prepare_for_saving`` now lives at the autograd boundary in
  ``_Linear.forward``; ``_linear_setup_ctx`` only returns the merged
  list of tensors that should be saved.
* ``grad_output_preprocess`` is invoked with ``bwd_args`` directly
  (it is duck-typed on the same attribute names) so backward never
  reaches into ``ctx.<attr>`` for non-tensor state.

Behaviour preserved (verified numerically against ``torch.nn.Linear``
and on FP8 + workspace-cache paths).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
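The save-at-the-boundary shape can be mimicked without torch (FakeCtx, linear_setup_ctx, and forward_boundary are illustrative names; under torch.library.register_autograd the setup_context callback plays the role of forward_boundary, as the only legal caller of save_for_backward):

```python
class FakeCtx:
    """Toy stand-in for the autograd context."""
    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors

def linear_setup_ctx(bwd_args, tensors_to_save_from_forward):
    # The helper hands back the merged tensor list; it never mutates
    # ctx itself, which is the shape register_autograd requires.
    return list(tensors_to_save_from_forward) + [bwd_args["weight"]]

def forward_boundary(ctx, inp, weight):
    out = [x * weight for x in inp]            # the "forward impl"
    to_save = linear_setup_ctx({"weight": weight}, [out])
    ctx.save_for_backward(*to_save)            # the only call site
    return out
```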
When ``saved_inputmat is inp``, ``wt_save is weight`` or ``bias`` is the
exact bias passed in, there is no point asking ``prepare_for_saving`` to
serialize the same Python object twice. Make ``_linear_forward_impl``
emit ``None`` in those slots (and a parallel ``saved_tensor_aliases``
tuple in ``ctx_attrs`` describing which slot points where), and have
``_linear_setup_ctx`` rebuild the tuple with the original references
before handing it to ``prepare_for_saving``.

Saves a Python ref per alias in eager and, more importantly, keeps the
forward helper from "returning" a tensor that aliases its own inputs --
a pattern ``torch.compile`` would otherwise need to reason about when
the helper is wrapped in an opaque op.

Numerically equivalent (validated against ``torch.nn.Linear`` and on a
multi-iteration FP8 path with workspace caching).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
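The dedup-and-rebuild round trip described above can be sketched like this (helper and argument names are illustrative and do not match the real code):

```python
def dedup_saved(saved, forward_inputs):
    """Emit None in save-slots that alias a forward input, plus a
    parallel aliases tuple recording which input each slot points at."""
    by_id = {id(t): name for name, t in forward_inputs.items()}
    aliases = tuple(by_id.get(id(t)) for t in saved)
    deduped = tuple(None if a is not None else t
                    for t, a in zip(saved, aliases))
    return deduped, aliases

def rebuild_saved(deduped, aliases, forward_inputs):
    """Inverse transform, done in setup_ctx from the original refs
    before handing the tuple to prepare_for_saving."""
    return tuple(forward_inputs[a] if a is not None else t
                 for t, a in zip(deduped, aliases))
```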
Follow-up cleanups on top of the dataclass refactor:

* Sort ``LinearFwdArgs`` / ``LinearBwdArgs`` fields into labelled groups
  (tensors, requires_grad flags, quantizers, dtype/numerical config,
  parallelism, userbuffers, FSDP, wgrad scheduling, misc) and mirror that
  ordering in their construction sites.
* Add ``slots=True`` to both dataclasses so typos in
  ``fwd_args.X`` / ``bwd_args.X`` raise ``AttributeError`` immediately
  instead of silently creating a new attribute.
* Inline single-use ``args.X`` aliases in ``_linear_forward_impl``
  (``weight_workspace``, ``fp8_calibration``, ``tp_size``,
  ``tensor_parallel``, ``cache_weight``, ``skip_fp8_weight_update``,
  ``custom``, ``backward_input_needs_gather``) so the prelude only keeps
  aliases that are actually reused.
* Shrink ``ctx_attrs`` to ``{fsdp_shapes, saved_tensor_aliases}``:
  ``weight_quantizer`` is re-derived in ``_linear_setup_ctx`` from
  ``fwd_args.weight`` (matching the resolution done in forward),
  ``is_fsdp2`` already lives on ``fwd_args``, and ``owns_input`` is
  equivalent to ``saved_tensor_aliases[0] != "inp"``.
* Replace ``setup_saved_tensors(saved_tensors, tensor_objects)`` with
  ``setup_saved_tensors(ctx)`` backed by ``restore_from_func_ctx``,
  matching ``layernorm_mlp`` / ``layernorm_linear`` /
  ``grouped_linear`` and dropping the manual
  ``ctx.tensor_objects = None`` cleanup.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
After packing the Linear backward state into ``LinearBwdArgs`` the
attributes the test was reading (``backward_override``, ``fp8``,
``grad_output_quantizer``, ``reduce_and_update_bwd_fp8_tensors``) no
longer live directly on ``grad_fn``. Read them from
``grad_fn.backward_objects`` when present, falling back to ``grad_fn``
for the linear-like modules that have not been refactored yet
(``layernorm_linear``, ``ops_linear``).

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
Restore the one-line class docstrings dropped during the field
reorganization so pylint stops warning about C0115.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
@pggPL pggPL marked this pull request as ready for review May 8, 2026 11:42
@greptile-apps
Contributor

greptile-apps Bot commented May 8, 2026

Greptile Summary

This PR refactors transformer_engine/pytorch/module/linear.py to pack the 25+ scattered ctx.something assignments into two typed dataclasses — LinearFwdArgs and LinearBwdArgs — and moves prepare_for_saving/ctx.save_for_backward to the autograd boundary, laying the groundwork for wrapping the linear op in a torch.library.custom_op.

  • LinearFwdArgs / LinearBwdArgs dataclasses replace the old positional non_tensor_args tuple and scattered ctx.* assignments; forward/backward helpers are each reduced to a single structured argument.
  • Alias deduplication stores save slots that alias forward inputs as None, reconstructing them from ctx_attrs["saved_tensor_aliases"] in _linear_setup_ctx, satisfying the torch.library.custom_op no-alias constraint.
  • ctx.backward_objects = None after the first backward call drops tensor/closure references but causes AttributeError on any second backward call — a regression vs. the old flat-ctx layout where attributes persisted across calls.

Confidence Score: 4/5

Safe to merge for the common single-backward case; callers using retain_graph=True will hit a new AttributeError crash on the second backward call that was not present before this refactor.

The mechanical ctx-to-dataclass translation is faithful and the return-tuple arity change (11 to 4 gradients) correctly matches the new 4-argument _Linear.forward. The one concrete regression is ctx.backward_objects = None after backward: the comment acknowledges retain_graph but nulling the outer pointer rather than individual fields means any second backward call immediately dereferences None and crashes — a path that worked in the old code.

transformer_engine/pytorch/module/linear.py — specifically the ctx.backward_objects = None teardown in _Linear.backward and its interaction with retain_graph=True.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/linear.py Large mechanical refactor packing 25+ scattered ctx attributes into LinearFwdArgs/LinearBwdArgs dataclasses; generally clean, but ctx.backward_objects = None after backward introduces a crash on any second backward call (retain_graph=True), which was previously safe.
tests/pytorch/test_backward_override.py Backward-override test helper updated to look for state in grad_fn.backward_objects (new LinearBwdArgs location) with a fallback to the old flat grad_fn attribute access; change is correct and backward-compatible with other modules.

Sequence Diagram

sequenceDiagram
    participant LF as Linear.forward
    participant App as _Linear.apply
    participant Fwd as _linear_forward_impl
    participant SetCtx as _linear_setup_ctx
    participant Bwd as _linear_backward

    LF->>LF: Build LinearFwdArgs (all config + tensors)
    LF->>App: apply(weight, inp, bias, fwd_args)
    App->>Fwd: _linear_forward_impl(fwd_args)
    Fwd-->>App: out, new_ws, tensors_to_save_from_forward, None, ctx_attrs
    App->>App: Create LinearBwdArgs()
    App->>SetCtx: _linear_setup_ctx(bwd_args, fwd_args, out, ctx_attrs, tensors)
    SetCtx->>SetCtx: Populate bwd_args fields from fwd_args
    SetCtx->>SetCtx: Reconstruct aliased tensors (inp/weight/bias)
    SetCtx-->>App: tensors_to_save_from_setup
    App->>App: prepare_for_saving + ctx.save_for_backward
    App->>App: "ctx.backward_objects = bwd_args"
    App-->>LF: out, new_weight_workspace

    Note over App,Bwd: Later - gradient computation
    App->>App: "bwd_args = ctx.backward_objects"
    App->>App: bwd_args.setup_saved_tensors(ctx)
    App->>Bwd: _linear_backward(bwd_args)
    Bwd-->>App: (wgrad, dgrad, grad_bias)
    App->>App: "ctx.backward_objects = None  breaks retain_graph"
    App-->>LF: (wgrad, dgrad, grad_bias, None)

Comments Outside Diff (1)

  1. transformer_engine/pytorch/module/linear.py, lines 1327-1336

    P1 ctx.backward_objects = None breaks re-entrant backward under retain_graph=True

    After the first backward call, ctx.backward_objects is set to None. If a second backward is triggered on the same graph node (i.e., retain_graph=True), the very first line of backwardbwd_args: LinearBwdArgs = ctx.backward_objects — yields None, and the immediately following bwd_args.grad_output = grad_output raises AttributeError: 'NoneType' object has no attribute 'grad_output'.

    In the old code, all backward state lived directly on ctx attributes, so a second call could proceed (saved tensors are preserved by PyTorch when retain_graph=True). A safer pattern is to null out the heavy fields (tensors, closures) inside bwd_args itself while leaving ctx.backward_objects pointing to the same (now mostly empty) bwd_args instance, so backward can always dereference the pointer and re-populate before running.


Saved tensors, quantizers, weakrefs and main_grad closures referenced
from LinearBwdArgs survived until ctx GC, extending peak GPU memory
under retain_graph=True. Null out ctx.backward_objects right after
_linear_backward so they are released as soon as backward returns.

Signed-off-by: Pawel Gadzinski <pgadzinski@nvidia.com>
@pggPL
Collaborator Author

pggPL commented May 8, 2026

/te-ci pytorch L1

Collaborator

@timmoon10 timmoon10 left a comment


LGTM. Most of my comments are just me talking through the design.

Comment on lines +1325 to +1328
# Drop all references held by bwd_args (saved tensors, quantizers, weakrefs,
# main_grad closure) so they don't outlive backward via ctx under retain_graph.
ctx.backward_objects = None
del bwd_args
If we destroy the cached state, we should also mark backward with function.once_differentiable.
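A toy guard in the same spirit (this is not torch's once_differentiable, which additionally marks the produced gradients as non-differentiable; it only makes the failure mode explicit, and Ctx/backward_once are illustrative names):

```python
import functools

def backward_once(fn):
    """Fail with a clear error if backward runs after its cached state
    was dropped, instead of an AttributeError deep inside the body."""
    @functools.wraps(fn)
    def wrapper(ctx, *grads):
        if getattr(ctx, "backward_objects", None) is None:
            raise RuntimeError(
                "backward state already released; re-entrant backward "
                "(retain_graph=True) is unsupported here")
        return fn(ctx, *grads)
    return wrapper

class Ctx:
    backward_objects = None

@backward_once
def backward(ctx, grad_out):
    bwd_args = ctx.backward_objects
    ctx.backward_objects = None   # release after first use
    return grad_out * bwd_args["scale"]
```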



@dataclass(slots=True)
class LinearFwdArgs:
This design is nicer than exposing a long list of positional args or passing in a tuple of args, but it is less nice than exposing kwargs. However, it has some advantages for this case:

  • torch.compile infrastructure will only need to handle one non-tensor arg to the autograd function.
  • Autograd functions do some processing on each arg, so minimizing the number of non-tensor args reduces CPU overhead.
  • _Linear.forward calls _linear_forward_impl and it doesn't need to be aware of the impl specifics.



@dataclass(slots=True)
class LinearBwdArgs:
In typical usage, LinearBwdArgs is somewhat redundant with the autograd context class. But we need it so that the forward-context-saving and backward are less entangled with autograd, which will allow code reuse when we implement an alternate torch.compile code path.

