Beartype perimeter on user-facing API #90

Merged: hmgaudecker merged 5 commits into af-estimator from feat/beartype-perimeter, May 15, 2026

Conversation

@hmgaudecker
Member

Stacked on the af-estimator branch.

Summary

Mirrors the pattern in pylcm PR #355: a per-exception BeartypeConf plus a beartype_init class decorator routes parameter-type violations at every documented entry point through a skillmodels-specific exception class. Callers can therefore write narrowly scoped except clauses against a stable hierarchy instead of catching beartype's framework exception.

Exceptions (src/skillmodels/exceptions.py)

Six TypeError subclasses, organised by perimeter:

  • ModelSpecInitializationError: FactorSpec, AnchoringSpec, ModelSpec, Normalizations
  • OptionsInitializationError: CHSEstimationOptions, AFEstimationOptions, AMNEstimationOptions
  • EstimationCallError: get_maximization_inputs, get_filtered_states, estimate_af, estimate_amn, get_af_posterior_states, get_amn_posterior_states
  • InferenceCallError: compute_af_standard_errors, compute_amn_standard_errors
  • SimulationCallError: simulate_dataset, simulate_policy_effect
  • DiagnosticsCallError: decompose_measurement_variance, summarize_measurement_reliability, plot_residual_boxplots, plot_likelihood_contributions, create_state_ranges, plot_correlation_heatmap, get_measurements_corr, get_quasi_scores_corr, get_scores_corr, univariate_densities, bivariate_density_contours, bivariate_density_surfaces, combine_distribution_plots, get_transition_plots, combine_transition_plots

All inherit from a common SkillmodelsInputError for callers that want to catch the whole hierarchy.
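The hierarchy is small enough to sketch in full. A hedged outline of what `exceptions.py` plausibly contains (class names are from the PR; docstrings and bodies are assumed):

```python
class SkillmodelsInputError(TypeError):
    """Common base: catch this to handle any perimeter violation."""


class ModelSpecInitializationError(SkillmodelsInputError):
    """Bad parameter types when constructing model-spec objects."""


class OptionsInitializationError(SkillmodelsInputError):
    """Bad parameter types when constructing estimation options."""


class EstimationCallError(SkillmodelsInputError):
    """Bad parameter types at estimation entry points."""


class InferenceCallError(SkillmodelsInputError):
    """Bad parameter types at inference entry points."""


class SimulationCallError(SkillmodelsInputError):
    """Bad parameter types at simulation entry points."""


class DiagnosticsCallError(SkillmodelsInputError):
    """Bad parameter types at diagnostics and plotting entry points."""
```

Because the base class itself derives from `TypeError`, any legacy `except TypeError` clause in caller code keeps working unchanged.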

Decorator + config (src/skillmodels/_beartype_conf.py)

  • _conf(exc): builds a BeartypeConf with violation_param_type=exc, strategy=BeartypeStrategy.On (full O(n) container scan; entry points are called rarely compared to the JIT-compiled hot path each one kicks off), and is_pep484_tower=True.
  • beartype_init(conf): class decorator that wraps only __init__. A bare @beartype on a class wraps every method, which surfaces non-public annotation drift on instance methods that has nothing to do with parameter validation at construction time.
  • Per-exception conf instances: MODEL_SPEC_CONF, OPTIONS_CONF, ESTIMATION_CONF, INFERENCE_CONF, SIMULATION_CONF, DIAGNOSTICS_CONF.
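A minimal stand-in for the two pieces, with a toy `isinstance` check in place of beartype's generated wrapper (the real code passes `BeartypeConf(violation_param_type=exc)` to `beartype`; everything below is illustrative):

```python
import inspect
from typing import get_type_hints


class ModelSpecInitializationError(TypeError):
    """Toy counterpart of the PR's exception class."""


def _wrap_init(init, exc_type):
    """Validate __init__ arguments against their annotations (toy checker)."""
    hints = get_type_hints(init)
    sig = inspect.signature(init)

    def wrapper(self, *args, **kwargs):
        bound = sig.bind(self, *args, **kwargs)
        for name, value in bound.arguments.items():
            expected = hints.get(name)  # 'self' has no hint and is skipped
            if isinstance(expected, type) and not isinstance(value, expected):
                raise exc_type(
                    f"{name}={value!r} violates annotation {expected.__name__}"
                )
        return init(self, *args, **kwargs)

    return wrapper


def beartype_init(exc_type):
    """Class decorator wrapping only __init__, mirroring the PR's design."""
    def decorate(cls):
        cls.__init__ = _wrap_init(cls.__init__, exc_type)
        return cls
    return decorate


@beartype_init(ModelSpecInitializationError)
class FactorSpec:
    def __init__(self, measurements: tuple):
        self.measurements = measurements
```

Wrapping only `__init__` is the design point: construction-time parameter validation without touching the rest of the class surface.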

Side effects

  • _check_measurements's type-shape arm in common/check_model.py is now dead code: the tuple[tuple[str, ...], ...] annotation on FactorSpec.measurements makes beartype reject every malformed measurement structure at construction time. The function is kept (the report aggregator might still surface non-type issues a beartype container scan can't see), but the corresponding two tests in tests/test_check_model.py are rewritten to assert ModelSpecInitializationError at FactorSpec(...) time.
  • tests/test_af_jaxopt_backend.py::test_optimizer_backend_rejects_unknown_value now asserts OptionsInitializationError from beartype's Literal check (which fires before AFEstimationOptions.__post_init__'s manual ValueError).
  • tests/test_amn_plot_harmonization.py::test_get_filtered_states_rejects_both_af_and_amn_results now asserts EstimationCallError; the prior body-level "only one of" ValueError is still in place but unreachable from this fixture, which passes the same AMN result to both parameters and trips the type guard on af_result first.
  • chs/filtered_states.py imports AFEstimationResult / AMNEstimationResult at runtime rather than under TYPE_CHECKING so beartype can resolve the annotation; ruff's TC003 autofix had been silently unforwarding the string forward refs.

Test plan

  • `pixi run -e tests-cpu pytest tests/ -q -k "not long_running"` — 529 passed, 1 deselected (same count as before this PR; no regressions)
  • `pixi run ty` — clean
  • `prek run --all-files` — clean

Out of scope (follow-up PRs)

  • Whole-package activation via `beartype.claw.beartype_package("skillmodels")` in `tests/conftest.py`. That probe would surface internal-helper annotation drift the same way pylcm's part-3 PR will, and is left for a separate review.
  • AGENTS-level conventions documentation for where future entry points should be decorated.

🤖 Generated with Claude Code

Stacked on the af-estimator branch (`2206212` series). Mirrors the
pattern in pylcm PR #355: a per-exception `BeartypeConf` plus a
`beartype_init` class decorator routes parameter-type violations at
every documented entry point through a skillmodels-specific
exception class, so callers can write narrowly-scoped `except`
clauses against a stable hierarchy rather than catching beartype's
framework exception.

Layout
------
* `src/skillmodels/exceptions.py` -- six `TypeError` subclasses,
  organised by perimeter (`ModelSpecInitializationError`,
  `OptionsInitializationError`, `EstimationCallError`,
  `InferenceCallError`, `SimulationCallError`,
  `DiagnosticsCallError`), all inheriting from a common
  `SkillmodelsInputError` for callers that want to catch the whole
  hierarchy in one go.
* `src/skillmodels/_beartype_conf.py` --
  `_conf(exc)` builds a `BeartypeConf` with
  `violation_param_type=exc`, `strategy=BeartypeStrategy.On` (full
  O(n) container scan; entry points are called rarely compared to
  the JIT-compiled hot path each one kicks off), and
  `is_pep484_tower=True`. `beartype_init(conf)` is a class decorator
  that wraps only `__init__` so non-public method-level annotation
  drift on instance methods does not surface at construction time.

Decoration sites
----------------
* `@beartype_init(MODEL_SPEC_CONF)` on `FactorSpec`, `AnchoringSpec`,
  `ModelSpec`, `Normalizations`.
* `@beartype_init(OPTIONS_CONF)` on `CHSEstimationOptions`,
  `AFEstimationOptions`, `AMNEstimationOptions`.
* `@beartype(conf=ESTIMATION_CONF)` on `get_maximization_inputs`,
  `get_filtered_states`, `estimate_af`, `estimate_amn`,
  `get_af_posterior_states`, `get_amn_posterior_states`.
* `@beartype(conf=INFERENCE_CONF)` on
  `compute_af_standard_errors`, `compute_amn_standard_errors`.
* `@beartype(conf=SIMULATION_CONF)` on `simulate_dataset`,
  `simulate_policy_effect`.
* `@beartype(conf=DIAGNOSTICS_CONF)` on
  `decompose_measurement_variance`,
  `summarize_measurement_reliability`,
  `plot_residual_boxplots`, `plot_likelihood_contributions`,
  `create_state_ranges`, `plot_correlation_heatmap`,
  `get_measurements_corr`, `get_quasi_scores_corr`,
  `get_scores_corr`, `univariate_densities`,
  `bivariate_density_contours`, `bivariate_density_surfaces`,
  `combine_distribution_plots`, `get_transition_plots`,
  `combine_transition_plots`.
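From the caller's side, the payoff of the perimeter is a narrowly scoped `except` clause. A hedged sketch with a stand-in entry point (the real function's beartype wrapper raises `EstimationCallError`; the manual check below only imitates that):

```python
class SkillmodelsInputError(TypeError):
    """Base of the perimeter hierarchy."""


class EstimationCallError(SkillmodelsInputError):
    """Raised by estimation entry points on bad parameter types."""


def estimate_af(data: dict) -> str:
    # Stand-in for the decorated entry point: in skillmodels the
    # beartype wrapper performs this rejection, not a manual check.
    if not isinstance(data, dict):
        raise EstimationCallError(
            f"data must be dict, got {type(data).__name__}"
        )
    return "estimated"


try:
    result = estimate_af([("y1", 1.0)])  # wrong container type
except EstimationCallError:
    result = "rejected"                  # narrow, stable except clause
```

The caller never needs to import or name beartype's own violation type, which is exactly the stability guarantee the PR is after.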

Side effects of perimeter-only validation
-----------------------------------------
* `_check_measurements`'s type-shape arm in
  `common/check_model.py` is now dead code: the
  `tuple[tuple[str, ...], ...]` annotation on
  `FactorSpec.measurements` makes beartype reject every malformed
  measurement structure at construction time. The function is kept
  (the report aggregator might still surface non-type issues a
  beartype container scan can't see), but the corresponding two
  tests in `tests/test_check_model.py` are rewritten to assert
  `ModelSpecInitializationError` at `FactorSpec(...)` time
  instead of asserting a soft message in the aggregator output.
* `tests/test_af_jaxopt_backend.py::test_optimizer_backend_rejects_unknown_value`
  now asserts `OptionsInitializationError` from beartype's
  `Literal` check (which fires before
  `AFEstimationOptions.__post_init__`'s manual ValueError).
* `tests/test_amn_plot_harmonization.py::test_get_filtered_states_rejects_both_af_and_amn_results`
  now asserts `EstimationCallError`; the prior body-level
  `"only one of"` ValueError is still in place but is unreachable
  from this fixture, which passes the same AMN result to both
  parameters and so trips the type guard on `af_result` first.
* `chs/filtered_states.py` imports `AFEstimationResult` /
  `AMNEstimationResult` at runtime rather than under
  `TYPE_CHECKING` so beartype can resolve the annotation; ruff's
  TC003 autofix had been silently unforwarding the string forward
  refs.
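The TYPE_CHECKING point can be seen directly with the standard library: a name imported only under `if TYPE_CHECKING:` does not exist at runtime, so a runtime checker cannot resolve the forward reference (`Decimal` here stands in for `AFEstimationResult`):

```python
from typing import TYPE_CHECKING, get_type_hints

if TYPE_CHECKING:
    from decimal import Decimal  # visible to static checkers only


def scale(x: "Decimal") -> "Decimal":
    return x * 2


try:
    get_type_hints(scale)  # tries to resolve the string forward ref
    resolved = True
except NameError:
    resolved = False
# resolved is False: the annotation is unresolvable at runtime, which is
# why beartype needs the import moved out of the TYPE_CHECKING block.
```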

Verification
------------
* `pixi run -e tests-cpu pytest tests/ -q -k "not long_running"` --
  529 passed, 1 deselected (same count as before this commit; no
  regressions).
* `pixi run ty` -- clean.
* `prek run --all-files` -- clean.

Out of scope (follow-up PRs)
----------------------------
* Whole-package activation via `beartype.claw.beartype_package("skillmodels")`
  in `tests/conftest.py`. That probe would surface internal-helper
  annotation drift the same way pylcm's part-3 PR will, and is left
  for a separate review.
* AGENTS-level conventions documentation. The perimeter is in place;
  the rule for where to put the next decorator is "wherever the
  signature is documented to the user" -- to be expanded once the
  pattern has settled.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@codecov

codecov Bot commented May 14, 2026

Codecov Report

❌ Patch coverage is 97.54902% with 5 lines in your changes missing coverage. Please review.
✅ Project coverage is 95.29%. Comparing base (2206212) to head (58344a5).
⚠️ Report is 5 commits behind head on af-estimator.

Files with missing lines                    | Patch % | Lines
src/skillmodels/common/fixed_constraint.py  | 78.94%  | 4 Missing ⚠️
src/skillmodels/af/jaxopt_backend.py        | 96.42%  | 1 Missing ⚠️
Additional details and impacted files
@@               Coverage Diff                @@
##           af-estimator      #90      +/-   ##
================================================
+ Coverage         95.05%   95.29%   +0.24%     
================================================
  Files               102      105       +3     
  Lines             10070    10207     +137     
================================================
+ Hits               9572     9727     +155     
+ Misses              498      480      -18     


hmgaudecker and others added 4 commits May 14, 2026 15:03
The previous fix (2206212) tried to address this by enabling x64 before
`import jaxopt`. The reproducer on sonny (jax 0.10.0, jaxopt 0.8.5)
shows that's insufficient: even with x64 on before any jaxopt code
runs and float64 inputs throughout, `LBFGSB.update`'s jit-compiled
`jnp.argsort` still emits an s32 reduction accumulator while the
surrounding scatter operand is built as s64. XLA's
`permutation_sort_simplifier` HLO pass rejects the mismatch with
`INVALID_ARGUMENT: Reduction function's accumulator shape at index 0
differs from the init_value shape: s32[] vs s64[]`.

Disabling just `permutation_sort_simplifier` via `XLA_FLAGS` fixes the
crash, keeps every other XLA optimisation intact, and is a no-op on
JAX < 0.10 (the pass doesn't exist there). The flag must be set before
`import jax` because XLA reads `XLA_FLAGS` once at backend init.

Applied in two places:
- `skillmodels/__init__.py`: the primary entry point. Appends to any
  pre-existing `XLA_FLAGS` so user flags aren't clobbered.
- `skillmodels/af/jaxopt_backend.py`: belt-and-suspenders for direct
  module imports that skip the package init.

The previous comment block tying the bug to "x64 off at import time"
was wrong about the root cause; replaced with the actual XLA pass
explanation. The `JAX_ENABLE_X64=1` setting is retained because the AF
pipeline assumes float64 throughout.

Verified end-to-end on sonny (jax 0.10): minimum jaxopt repro that
previously crashed now succeeds. Local jaxopt backend tests (7) and
full local suite (485 tests, jax 0.9) still pass.
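The append-don't-clobber pattern described for `skillmodels/__init__.py` can be sketched in a few lines (the flag name is the one quoted in the commit; the surrounding package context is assumed):

```python
import os

# Disable only the offending XLA pass while preserving any flags the
# user already set in XLA_FLAGS.
FLAG = "--xla_disable_hlo_passes=permutation_sort_simplifier"

existing = os.environ.get("XLA_FLAGS", "")
if FLAG not in existing:
    os.environ["XLA_FLAGS"] = (existing + " " + FLAG).strip()

# `import jax` must happen only after this point: XLA reads XLA_FLAGS
# once at backend initialisation, so a later change has no effect.
```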

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The perimeter decorators (PR #90 / commit d7d29ea) covered only public
entry points. The claw activation in `tests/conftest.py` extends type
enforcement to every annotated callable in the package during the test
run, catching annotation drift on internal helpers that would otherwise
silently flow through.

Configuration:
- `is_pep484_tower=True` to mirror the perimeter conf so `int` satisfies
  `float`-typed parameters.
- `claw_skip_package_names=("skillmodels.chs.qr",)` because JAX's
  `@custom_jvp` decorator stores the secondary `.defjvp` setter on the
  wrapped object; beartype.claw's wrapping strips it.

Annotation drift fixes (sources of truth, not type-system theater):

* `FixedConstraintWithValue` moved to its own module
  `common/fixed_constraint.py`. `transition_functions.py` previously
  imported it under `if TYPE_CHECKING:` to avoid a circular import with
  `constraints.py`; beartype.claw can't resolve those forward refs at
  decoration time. The leaf type now lives where both modules can pull
  it without a cycle.
* Same TYPE_CHECKING removal for `ModelSpec` in `af/types.py` and
  `amn/types.py`.
* JAX-traced helpers (`_at_node`, `_chain_one_component`,
  `_compute_investment`, kalman / likelihood entry points, transition
  pipeline plumbing): annotations relaxed to accept `Array | np.ndarray`
  / `float | Array` / `int | np.integer` where the runtime contract is
  genuinely mixed. JAX vmap traces ints as `BatchTracer`; numpy and
  jax arrays interconvert freely through these signatures.
* `MixtureComponent`, `ConditionalDistribution`, `ChainLink` dataclass
  fields: accept `np.ndarray` alongside `jax.Array` since estimators
  fill them with both.
* `TransitionInfo.param_names`: now built with explicit `tuple()`
  conversion at the boundary in `process_model.py` so the
  `MappingProxyType[str, tuple[str, ...]]` annotation actually holds.
* `get_has_endogenous_factors` now casts the pandas `.any()` result to
  `bool` instead of relying on a `# ty: ignore`.
* `NDArray[np.floating[Any]]` (which beartype doesn't accept as a dtype
  hint) replaced with `NDArray[np.float64]` in `chs/process_debug_data`
  and `common/visualize_factor_distributions`.
* `NDArray[np.floating]` in `common/simulate_data` widened to
  `NDArray[np.floating] | Array`.
* Internal duck-typed validators (`_check_anchoring`, `_process_factors`)
  re-typed as `Any` with `# noqa: ANN401` and an inline comment; they
  exist precisely to take partially-built objects.
* `_aug_periods_from_period`: `dict[int, int]` → `Mapping[int, int]`
  (production passes a `MappingProxyType`).
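The last point is visible directly at runtime: a `MappingProxyType` is a `Mapping` but not a `dict`, so the stricter `dict[int, int]` annotation would be rejected by a runtime checker while `Mapping[int, int]` is satisfied.

```python
from collections.abc import Mapping
from types import MappingProxyType

proxy = MappingProxyType({0: 1, 1: 2})

is_dict = isinstance(proxy, dict)        # False: dict[int, int] rejects it
is_mapping = isinstance(proxy, Mapping)  # True: Mapping[int, int] accepts it
```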

Tests:
- `tests/test_check_model.py`: re-add `# ty: ignore` on the two
  `FactorSpec(measurements=...)` calls that intentionally pass `list`
  where `tuple` is required; they verify the beartype perimeter
  catches the shape error.
- `tests/test_transition_functions.py::test_constant`: rewritten to
  pass real JAX arrays now that the claw type-checks `constant`.
- Stale `# ty: ignore[invalid-argument-type]` directives stripped from
  `test_check_model.py`, `test_correlation_heatmap.py`,
  `test_process_debug_data.py` (and one in `simulate_data.py`) — the
  annotations they were silencing have been relaxed.

Verification: 495 tests pass with claw enabled (`pixi run -e tests-cpu
tests`); `pixi run ty` clean; `prek run --all-files` clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
CI run 25868871644 surfaced 14 failures in `tests/test_af_inference.py`
that local runs missed (the file is unmarked but the local sweep
excluded similarly-shaped tests via `-k "not long_running and not
end_to_end"`).

All failures had the same root cause as the bulk of the prior
claw-activation diff: AF transition / chain helpers were annotated
with strict `jax.Array` parameters, but the runtime path through
`af.inference` constructs `prev_distribution`, chain link inputs, and
chol/mean blocks from `np.ndarray`. Beartype rejected them under
the now-active claw.

Relaxed sites:
- `af.likelihood.af_per_obs_loglike_transition` /
  `af_loglike_transition` / `_integrate_transition_chain`:
  `prev_distribution` widened from `dict[str, Array]` to
  `Mapping[str, Array | np.ndarray]`. `Mapping` (covariant) lets
  callers pass `dict[str, Array]` without an explicit cast.
- `af.likelihood._map_over_obs`: `*xs: Array` → `*xs: Array | np.ndarray`.
- `af.likelihood._integrate_transition_single_obs`: `obs_cond_weights`,
  `obs_cond_means`, `cond_chols` widened to `Array | np.ndarray`.
- `af.likelihood._rebuild_chain_at_period`: `initial_mean`,
  `initial_chol` widened to `Array | np.ndarray`. Internal `theta`
  bound through `jnp.asarray(...)` so downstream `_compute_investment`
  still sees an `Array`.
- `af.likelihood._compute_investment`: `inv_eq_params`, `inv_sds`
  widened (covered earlier; ty-narrowing followed naturally).

Verification:
- `pixi run -e tests-cpu pytest tests/test_af_inference.py` — 14 / 14
  pass (was 1 failed + 13 errors).
- `pixi run ty` clean.
- `prek run --all-files` clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Previously stopped on `max|projected_grad| < tol` only. That's not how
scipy_lbfgsb stops in practice: scipy stops on EITHER `gtol_abs` OR a
relative function-value decrease `ftol_rel`. On the skill-formation
likelihoods used in the Monte Carlo benchmarks, the loglikelihood goes
locally flat before the gradient does, so scipy's ftol channel fires
~100% of the time and gradient-norm < 1e-5 is essentially never the
actual stopping criterion in production.

Without the ftol channel, jaxopt would grind down the gradient while
scipy declared success at the same point — a fake apples-to-oranges
asymmetry that made the AF-jaxopt vs AF-optimagic timing comparison
meaningless.

Implementation: drop jaxopt's built-in `run()` and drive the solver
through an explicit `init_state` + `update` loop with the same
gtol-OR-ftol stop. Default values now match scipy_lbfgsb's defaults
(`gtol_abs=1e-5`, `ftol_rel=2.22e-9`, `maxiter=15000`). The wrapper
accepts both canonical scipy keys (`convergence_gtol_abs`,
`convergence_ftol_rel`, `stopping_maxiter`) and the historical jaxopt
keys (`tol`, `maxiter`) so the same `optimizer_options` dict works
for either backend.

This makes the two LBFGSB implementations stop on byte-identical
rules; the only remaining differences are internal (line search,
step acceptance, curvature-pair filtering) — which is the comparison
that's actually interesting.
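The dual stopping rule can be sketched on a toy problem. This is plain gradient descent on a 1-D quadratic, not jaxopt's LBFGSB update loop; only the gtol-OR-ftol exit logic and the default tolerances (the scipy_lbfgsb values quoted above) mirror the commit:

```python
def minimize(f, grad, x0, gtol_abs=1e-5, ftol_rel=2.22e-9, maxiter=15000):
    """Toy optimizer stopping on EITHER small gradient OR small
    relative function-value decrease, as scipy_lbfgsb does."""
    x, fx = x0, f(x0)
    for _ in range(maxiter):
        g = grad(x)
        if abs(g) < gtol_abs:  # gtol channel
            return x, "gtol"
        x_new = x - 0.1 * g
        f_new = f(x_new)
        # scipy-style relative decrease test
        denom = max(abs(fx), abs(f_new), 1.0)
        if (fx - f_new) / denom <= ftol_rel:  # ftol channel
            return x_new, "ftol"
        x, fx = x_new, f_new
    return x, "maxiter"


# On this quadratic the objective goes flat faster than the gradient
# shrinks, so the ftol channel fires first, mirroring the commit's
# observation about the skill-formation likelihoods.
x_star, reason = minimize(lambda x: (x - 3.0) ** 2,
                          lambda x: 2.0 * (x - 3.0), 0.0)
```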

Verified: 7 jaxopt_backend tests still pass; ty clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@hmgaudecker hmgaudecker merged commit 58344a5 into af-estimator May 15, 2026
4 checks passed
@hmgaudecker hmgaudecker deleted the feat/beartype-perimeter branch May 15, 2026 09:26