CPU overhead optimizations for te autocast #2957
vthumbe1503 wants to merge 3 commits into NVIDIA:main
Conversation
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Greptile Summary
This PR reduces CPU overhead on Grace/GB200 systems by caching recipe string representations and replacing the generator-based autocast context manager with a class-based one.
Confidence Score: 5/5
The changes are straightforward performance optimizations with no functional regressions for the normal usage pattern. Both changes (repr caching and class-based context manager) are behaviorally equivalent to the code they replace for all standard single-use patterns. No files require special attention beyond the minor issues flagged below.
Sequence Diagram
sequenceDiagram
participant User
participant autocast
participant FP8GlobalStateManager
participant Recipe
User->>autocast: __init__(enabled, recipe, ...)
note over autocast: stores args, _fp8_state=None
User->>autocast: __enter__()
autocast->>Recipe: __repr__() [if needed for key]
Recipe-->>autocast: cached _cached_repr
autocast->>FP8GlobalStateManager: get_autocast_state()
FP8GlobalStateManager-->>autocast: fp8_state snapshot
autocast->>FP8GlobalStateManager: autocast_enter(enabled, recipe, ...)
autocast-->>User: self
note over User: training step
User->>autocast: __exit__(exc_type, exc_val, exc_tb)
autocast->>FP8GlobalStateManager: set_autocast_state(fp8_state)
autocast->>FP8GlobalStateManager: autocast_exit(enabled, _graph)
Reviews (2): Last reviewed commit: "[pre-commit.ci] auto fixes from pre-comm..."
    recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
    if recipe_repr is None:
        recipe_repr = str(recipe)
    group_id = id(group) if group is not None else 0
    return f"{recipe_repr}|{group_id}"
Key format change could produce ambiguous keys
The new key format f"{recipe_repr}|{group_id}" uses | as a separator without escaping. If a future recipe's __repr__ ever emits a | character, two distinct (recipe, group) pairs could map to the same string. The old str(tuple) format was unambiguous because it quoted the recipe repr. A safer pattern uses a separator that cannot appear in repr output, or encodes the parts deterministically.
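One way to encode the parts unambiguously (a minimal sketch with a hypothetical helper name, not code from the PR) is to let repr() quote and escape the recipe string, restoring the property the old str(tuple) key had:

    def _make_autocast_key(recipe=None, group=None):
        # repr() of a str wraps it in quotes and escapes special characters,
        # so a literal "|" inside the recipe repr can never be confused with
        # the "|" used here as the field separator.
        group_id = id(group) if group is not None else 0
        return f"{str(recipe)!r}|{group_id}"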
Suggested change:
-    recipe_repr = recipe.__dict__.get("_cached_repr") if recipe is not None else None
-    if recipe_repr is None:
-        recipe_repr = str(recipe)
-    group_id = id(group) if group is not None else 0
-    return f"{recipe_repr}|{group_id}"
+    group_id = id(group) if group is not None else None
+    return f"recipe=({str(recipe)}),group={group_id}"
    def __enter__(self) -> "autocast":
        if self._enabled:
            check_recipe_support(self._recipe)
        # Save current state so we always restore it on exit.
        self._fp8_state = FP8GlobalStateManager.get_autocast_state()
        FP8GlobalStateManager.autocast_enter(
            enabled=self._enabled,
            calibrating=self._calibrating,
            fp8_recipe=self._recipe,
            fp8_group=self._amax_reduction_group,
            _graph=self._graph,
        )
        return self

    FP8GlobalStateManager.autocast_enter(
        enabled=enabled,
        calibrating=calibrating,
        fp8_recipe=recipe,
        fp8_group=amax_reduction_group,
        _graph=_graph,
    )
    try:
        yield
    finally:
        FP8GlobalStateManager.set_autocast_state(fp8_state)
        FP8GlobalStateManager.autocast_exit(enabled, _graph=_graph)

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        FP8GlobalStateManager.set_autocast_state(self._fp8_state)
        FP8GlobalStateManager.autocast_exit(self._enabled, _graph=self._graph)
        # Do not suppress exceptions.
        return None
Nested reuse of the same instance silently corrupts state
The old generator-based implementation raised RuntimeError: generator already executing if you tried to enter the same context manager object twice concurrently. The new class-based implementation silently accepts nested reuse, but the second __enter__ call overwrites self._fp8_state with the state captured inside the first context, so the outer __exit__ restores the wrong state permanently.
    ctx = autocast(enabled=True, recipe=recipe)
    with ctx:          # _fp8_state = pre_context_state
        with ctx:      # _fp8_state = state_inside_first_block ← overwrites!
            pass       # inner __exit__: restores state_inside_first_block
        # _fp8_state is now state_inside_first_block
    # outer __exit__: restores state_inside_first_block, NOT pre_context_state ← bug

Adding a guard in __enter__ would preserve the old safety behavior:
    def __enter__(self) -> "autocast":
        if self._fp8_state is not None:
            raise RuntimeError("autocast context manager cannot be entered more than once concurrently")
        ...
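For this guard to still allow sequential (non-nested) reuse of the same instance, __exit__ would also have to clear the saved state after restoring it; a hypothetical continuation of the same sketch:

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        FP8GlobalStateManager.set_autocast_state(self._fp8_state)
        FP8GlobalStateManager.autocast_exit(self._enabled, _graph=self._graph)
        # Clear the marker so a later, non-concurrent `with` on this
        # instance passes the guard in __enter__.
        self._fp8_state = None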
    def __repr__(self) -> str:
        cached = self.__dict__.get("_cached_repr")
        if cached is not None:
            return cached
        result = f"MMParams(use_split_accumulator={self.use_split_accumulator})"
        object.__setattr__(self, "_cached_repr", result)
        return result
_cached_repr stored outside declared dataclass fields
MMParams is @dataclass(frozen=True). Storing _cached_repr via object.__setattr__ bypasses the frozen guard correctly in CPython, but _cached_repr is not a declared dataclass field — it won't appear in dataclasses.fields(), dataclasses.asdict(), dataclasses.astuple(), or copy.replace(). If downstream code serializes or copies an MMParams instance, the cached repr would be lost silently. Documenting this with a comment or declaring it as field(init=False, repr=False, compare=False) would make the intent clearer. The same applies to QParams.
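A tiny standalone demonstration of the pitfall (hypothetical, using a stripped-down MMParams): attributes set via object.__setattr__ are invisible to the dataclasses machinery, so copies silently drop the cache:

    import dataclasses

    @dataclasses.dataclass(frozen=True)
    class MMParams:
        use_split_accumulator: bool = True

    p = MMParams()
    object.__setattr__(p, "_cached_repr", "MMParams(use_split_accumulator=True)")
    print(dataclasses.asdict(p))        # {'use_split_accumulator': True} -- no cache
    q = dataclasses.replace(p)          # fresh instance built from declared fields only
    print(hasattr(q, "_cached_repr"))   # False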
I see that this is why we're doing the funny accesses with __dict__. I agree that bypassing frozen=True is iffy, so I wonder if we could set _cached_repr in __post_init__? If the class is frozen, its repr must also be frozen and I don't see a benefit in lazy evaluation.
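A minimal sketch of that __post_init__ approach (illustrative only, on a stripped-down MMParams): declaring the cache as a real field keeps it visible to the dataclasses machinery, and object.__setattr__ is the documented way to assign inside __post_init__ of a frozen dataclass:

    from dataclasses import dataclass, field

    @dataclass(frozen=True)
    class MMParams:
        use_split_accumulator: bool = True
        _cached_repr: str = field(init=False, repr=False, compare=False)

        def __post_init__(self) -> None:
            object.__setattr__(
                self,
                "_cached_repr",
                f"MMParams(use_split_accumulator={self.use_split_accumulator})",
            )

        def __repr__(self) -> str:
            return self._cached_repr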
    # changes. This makes repeated ``str(recipe)`` calls (e.g. on the hot
    # path in ``FP8GlobalStateManager.get_unique_autocast_key``) essentially
    # free after the first call.
    _cached_repr: Optional[str] = None
Three problems:
- _cached_repr is being set as a class attr, not an instance attr.
- Accessing _cached_repr via __dict__ is non-standard and bug-prone.
- Splitting the cache logic between the base class and child classes results in code duplication and more risk of bugs, especially if it involves non-standard __dict__ accesses.
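A quick standalone illustration of the first two problems (hypothetical names, independent of the PR): a class-level default never lives in the instance __dict__, which is why the cache lookup needs the non-standard __dict__.get and silently misses until an assignment shadows the class attribute:

    class Recipe:
        _cached_repr = None  # class attribute, shared by every instance

    r = Recipe()
    print(r._cached_repr)                   # None, found on the class
    print(r.__dict__.get("_cached_repr"))   # None -- nothing instance-level yet
    r._cached_repr = "DelayedScaling(...)"  # assignment creates an instance attr
    print(r.__dict__.get("_cached_repr"))   # 'DelayedScaling(...)'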
What if we concentrated the caching logic in the base class:
    class Recipe:
        def __init__(self) -> None:
            self._cached_repr: Optional[str] = None

        @abc.abstractmethod
        def _make_repr(self) -> str:
            ...

        def __repr__(self) -> str:
            if self._cached_repr is None:
                self._cached_repr = self._make_repr()
            return self._cached_repr

    ...

    class DelayedScaling(Recipe):
        def _make_repr(self) -> str:
            return f"..."

    # directly getting the cached repr is about 40 ns faster than str(recipe)
    # on grace systems.
This is good to mention in the PR description, but not that useful in the code itself. Profiling becomes outdated once we move on to the next architecture.
    # Class-based context manager (instead of ``@contextmanager`` from contextlib)
    # to avoid the ~0.5us / invocation overhead of contextlib's generator-driven
    # ``GeneratorContextManager``. ``__slots__`` further avoids per-instance
    # dict allocation.
Why are we mentioning the context manager here? It makes sense for this PR, but once the code is merged it will be completely random. This comment should explain what we are doing with __slots__, and we should explain the custom context manager logic in __enter__ and __exit__.
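The gap the comment describes can be checked with a standalone micro-benchmark (a sketch, not part of the PR; absolute numbers vary by machine and Python version):

    import timeit
    from contextlib import contextmanager

    @contextmanager
    def gen_cm():
        yield

    class ClassCM:
        __slots__ = ()

        def __enter__(self):
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            return None

    def use_gen():
        with gen_cm():
            pass

    def use_cls():
        with ClassCM():
            pass

    n = 1_000_000
    print("generator CM per call:", timeit.timeit(use_gen, number=n) / n)
    print("class CM per call:    ", timeit.timeit(use_cls, number=n) / n)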
    # Do not suppress exceptions.
    return None
Nit: The function already returns None implicitly, and the comment is trivially true (any Python code outside a try statement does not suppress exceptions).
Suggested change (delete):
-    # Do not suppress exceptions.
-    return None
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
for more information, see https://pre-commit.ci
Description
te autocast has quite a bit of CPU overhead on Grace systems.
Here are the results on GB200 after the optimizations:
Without Optimizations

Optimization 1 --> Cache the recipe string representation used to build the unique autocast key. Also, directly accessing the cached representation of the recipe is ~40 ns faster than going through str(recipe).
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: