Add Antweiler-Freyberger (2025) iterative quadrature estimator #89
hmgaudecker wants to merge 102 commits into
Conversation
New af/ subpackage implementing period-by-period MLE with Halton quadrature as an alternative to the CHS Kalman filter estimator. Same ModelSpec interface, JAX AD for gradients, arbitrary factor count. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
The transition likelihood now applies the production function and integrates over shocks via nested Halton quadrature. Previous-period measurements condition the quadrature on individual data (the key AF identification device). State propagation uses quadrature-based moment matching. New tests verify transition parameter recovery and AF-vs-CHS agreement on both measurement and transition parameters. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Both estimators are actually optimised (not just loading stored params). Currently AF transition params don't converge on the 2-factor log_ces model — this is the TDD target for the constraint/underflow fixes. Skipped in CI via `long_running` marker; run with `-m long_running`. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Both estimators now start from: loadings=1, controls=0, everything else=0.5, probability constraints satisfied with equal shares. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
- Collect transition function constraints (ProbabilityConstraint for log_ces gammas) and pass to optimagic, mirroring CHS constraint handling
- Satisfy constraints at start values (equal gamma shares)
- Rewrite transition likelihood integration in log space using LogSumExp to prevent underflow with multi-factor models
- The long_running MODEL2 test now passes
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Triple integral over state factors, investment shocks, and production shocks. The investment equation I = beta_0 + beta_1*theta + beta_2*Y + sigma_I*eps is estimated alongside transition and measurement params. Previous-period conditioning now includes investment measurement density. ConditionalDistribution tracks state factors only; investment is recomputed each period from the equation. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Users can pass a DataFrame of starting values to estimate_af(). Matching index entries override heuristic defaults; unmatched and fixed parameters are left unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #89      +/-   ##
==========================================
- Coverage   96.91%   95.26%   -1.65%
==========================================
  Files          57      105      +48
  Lines        4952    10217    +5265
==========================================
+ Hits         4799     9733    +4934
- Misses        153      484     +331
Common public interface: get_filtered_states(model_spec, data, params, af_result=None). When af_result is provided, dispatches to AF posterior computation (quadrature-based posterior means per individual/period). Internally uses af/posterior_states.py. Returns "unanchored_states" matching the CHS output format. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
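A hedged usage sketch of the dispatch described above; `model_spec`, `data`, `chs_params`, `af_params`, and `af_result` are assumed to come from earlier estimation runs, and the flat top-level import reflects the package layout at this point in the branch.

```python
# Illustrative only: the objects below are assumed to exist from prior CHS / AF
# estimation runs; nothing here is a verbatim copy of skillmodels code.
from skillmodels import get_filtered_states

# CHS path: no af_result, the Kalman-filter states are returned.
chs_states = get_filtered_states(model_spec, data, chs_params)

# AF path: passing af_result dispatches to the quadrature-based posterior means
# (computed per individual and period via af/posterior_states.py).
af_states = get_filtered_states(model_spec, data, af_params, af_result=af_result)

# Both calls return "unanchored_states" in the same CHS-compatible format.
```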
Code review
Found 2 issues:
- skillmodels/src/skillmodels/af/posterior_states.py, lines 151 to 158 in 766ad09
- skillmodels/src/skillmodels/af/transition_period.py, lines 246 to 250 in 766ad09
🤖 Generated with Claude Code
1. Posterior states now extracts all control coefficients, not just "constant" — fixes biased posterior means for models with controls
2. Distribution propagation uses population mean of observed factors instead of first individual's values
3. AFEstimationResult.model_spec typed as ModelSpec (was Any)
4. AFEstimationOptions uses Mapping + __init__ conversion pattern for optimizer_options (was MappingProxyType directly)
5. Remove redundant "loadings_flat" key from _parse_initial_params
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Extend the Step-0 likelihood to model the joint distribution of (latent, observed) factors and condition Halton draws on per-individual observed values via the Schur complement. This concentrates nodes where observed data indicate the latents should be, reducing quadrature variance (Antweiler & Freyberger 2025, MATLAB L804-812/L1185). Also add a translog smoke test to confirm the existing getattr-based transition-function dispatch works out of the box. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
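The conditioning device mentioned above is standard Gaussian conditioning; the sketch below shows the Schur-complement formulas on a generic joint normal, using illustrative names rather than the skillmodels internals.

```python
# Minimal sketch of Gaussian conditioning via the Schur complement, the device
# used to concentrate Halton nodes on each individual's observed values.
import numpy as np

def condition_gaussian(mean, cov, n_latent, y_obs):
    """Mean and covariance of the latent block given the observed block y_obs."""
    mu_l, mu_o = mean[:n_latent], mean[n_latent:]
    s_ll = cov[:n_latent, :n_latent]
    s_lo = cov[:n_latent, n_latent:]
    s_oo = cov[n_latent:, n_latent:]
    gain = s_lo @ np.linalg.inv(s_oo)
    cond_mean = mu_l + gain @ (y_obs - mu_o)
    cond_cov = s_ll - gain @ s_lo.T          # Schur complement of s_oo
    return cond_mean, cond_cov

mean = np.array([0.0, 0.0, 1.0])
cov = np.array([[1.0, 0.3, 0.2],
                [0.3, 1.0, 0.1],
                [0.2, 0.1, 0.5]])
m, v = condition_gaussian(mean, cov, n_latent=2, y_obs=np.array([1.4]))

# Standard-normal Halton draws z are then mapped through the conditional
# Cholesky factor: theta_draw = m + np.linalg.cholesky(v) @ z.
```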
Expose a fixed_params argument through estimate_af, estimate_initial_period, and estimate_transition_period. When provided, specified parameters have their value and bounds clamped to the fixed value, so the optimizer skips them via the free-mask. Primary use case: pin time-invariant latent factors (e.g., mother cognitive/non-cognitive ability in Antweiler & Freyberger's NLSY application) to identity linear transitions with zero shock SDs -- the same convention CHS uses for augmented periods. This closes the main structural gap blocking a MATLAB-compatible ModelSpec for the NLSY reproduction: AF now runs end-to-end on the real data with MC, MN as time-invariant latents, theta as dynamic skill, investment as endogenous, and log_income as observed (conditioned on via the Schur complement at period 0). Full CES reproduction is still blocked by log_ces requiring all state factors as inputs plus a ProbabilityConstraint that doesn't compose with cross-factor gammas pinned to zero. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Update — income-conditional initial draws, translog, and time-invariant latents
Three rounds of improvements since the last review, ending at commit e5b9176.
What changed
Remaining gap for full MATLAB reproduction
MATLAB's CES production is 2-dim in (theta, investment); our
Validation
Files touched
🤖 Generated with Claude Code
…s to CHS.
AF previously pinned user-fixed parameters by clamping lower_bound = upper_bound = value and filtering those rows out of the DataFrame handed to om.minimize. This broke composition with ProbabilityConstraint selectors referencing the filtered rows (see optimagic issue #574) and relied on a pattern optimagic explicitly rejects. Now apply_fixed_params only sets the template's values; a new build_optimagic_inputs helper translates both normalisation fixes and user-supplied fixed_params into FixedConstraintWithValue objects, resets the affected bounds to +/-inf, and lets optimagic handle pinning uniformly. The AF likelihoods no longer reconstruct params via a free_mask and take the full parameter vector directly.
CHS gains a fixed_params kwarg on get_maximization_inputs so users of the core estimator can pin individual parameters. Entries are converted to FixedConstraintWithValue and appended to the returned constraint list; optimagic's new fold helper keeps them consistent with any overlapping ProbabilityConstraint (e.g. a log_ces gamma).
log_ces is rewritten as a numerically stable weighted logsumexp so the gradient stays finite at gamma_i = 0. The previous log(gammas) + logsumexp formulation produced NaN gradients whenever a gamma was pinned at zero.
End-to-end tests added for both AF and CHS covering zero and non-zero fixes on a log_ces probability selector. Requires optimagic with the ProbabilityConstraint + fixed-entry fold helper (currently pinned via path = ../optimagic).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
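A minimal sketch of the weighted-logsumexp idea behind the log_ces rewrite, assuming an aggregator of the form logsumexp(phi * log_states, weights) / phi; the actual skillmodels parameterisation and signature may differ.

```python
# Sketch of the numerically stable form: passing gammas as logsumexp's `b`
# weights avoids the log(gammas) term, so the gradient stays finite even when
# a gamma is pinned at exactly zero.
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_ces_stable(log_states, gammas, phi):
    return logsumexp(phi * log_states, b=gammas) / phi

log_states = jnp.array([0.3, -0.1, 0.5])
gammas = jnp.array([0.7, 0.3, 0.0])   # third share pinned to zero
grad = jax.grad(log_ces_stable, argnums=1)(log_states, gammas, 0.5)
print(grad)                           # finite, including the pinned entry
```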
Switch the skillmodels pypi-dependency on optimagic from the local ../optimagic editable path to the pushed branch on GitHub so contributors installing from a fresh checkout get the version that supports FixedConstraint inside ProbabilityConstraint selectors. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Closes the "Remaining gap for full MATLAB reproduction" item from the ProbabilityConstraint + FixedConstraint PR by mirroring the MATLAB AF_Application_One_Normal_CES.m and _Translog.m runs in skillmodels: - tests/matlab_ces_repro/load_cnlsy.py reads complete_7_9_11.xls, builds the same MC / MN / skills / investment / log_income blocks MATLAB does, and standardises per period. - tests/matlab_ces_repro/matlab_mapping.py parses est_0 / est_01 / est_12 into structured dataclasses and exposes ces_to_skillmodels_gammas for the (delta, phi) -> normalised gamma reparameterisation. - tests/matlab_ces_repro/model_specs.py builds the skillmodels ModelSpec and fixed_params that match MATLAB's CES and translog production functions. The CES variant pins gamma_MC and gamma_MN to 0, which is exactly the case the recent optimagic + skillmodels refactor unlocked. - tests/matlab_ces_repro/test_af_matlab_repro.py runs both variants end-to-end. Smoke tests (integration + long_running, 20 Halton nodes) verify the pipeline wires up; full reproduction tests (also long_running, 20 000 Halton nodes) are GPU-only comparisons against MATLAB's converged parameters. - Unit tests for the data loader and parameter parser run fast on CPU. Adds xlrd to the tests feature for .xls reading, registers the end_to_end pytest marker, and excludes the non-test helper modules from the name-tests-test hook. Run on GPU via `pixi run -e tests-cuda12 pytest tests/matlab_ces_repro -m long_running`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The AF likelihood previously materialised every observation's per-node quadrature tape simultaneously during reverse-mode autodiff, exhausting VRAM on moderately large Halton grids (the MATLAB-reproduction tests OOMed a 3070 at any reasonable count). Two complementary changes fix the per-observation scaling:
- jax.checkpoint on each per-obs integrand in af/likelihood.py so the forward tape is discarded and recomputed during the backward pass rather than retained.
- jax.lax.map (replacing the outer jax.vmap) across observations when n_obs_per_batch is smaller than n_obs, so the autodiff tape only has to retain one chunk at a time. A helper _map_over_obs falls back to vmap when batching is off.
New public knobs:
- AFEstimationOptions.n_obs_per_batch. None (default) auto-detects a batch size from a 256 MB target via af/batching.auto_n_obs_per_batch.
- SKILLMODELS_AF_TARGET_BATCH_BYTES env var overrides the target.
Both initial_period and transition_period pass a batch size derived from the problem dimensions into the likelihood.
Correctness: tests/test_af_batching.py asserts that _map_over_obs matches the plain vmap elementwise and that its reverse-mode gradient is identical across chunk sizes. The existing test_af_estimate.py suite still passes with no measurable change.
Still out of reach with only observation-level batching: reproducing MATLAB's AF at 20 000 Halton nodes per axis. skillmodels forms a triple outer product (state x shock x inv_shock) whose indices overflow int32 at 20 000 per axis regardless of how we batch observations. Documented as a follow-up; a node-axis lax.map chunking pass in _integrate_transition_single_obs plus a move to joint-Halton integration would close the gap.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
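A sketch of the chunked-observation mapping idea (checkpointed per-observation integrand, lax.map over chunks, vmap fallback); helper names and shapes are simplified relative to `_map_over_obs`.

```python
# Illustrative chunked mapping: checkpoint the per-obs function so its forward
# tape is recomputed in the backward pass, and map over observation chunks with
# lax.map so only one chunk's tape is alive at a time.
import jax
import jax.numpy as jnp

def map_over_obs(per_obs_fun, obs_data, n_obs_per_batch=None):
    per_obs_fun = jax.checkpoint(per_obs_fun)
    if n_obs_per_batch is None:
        return jax.vmap(per_obs_fun)(obs_data)          # batching off
    n_obs = obs_data.shape[0]
    n_batches = -(-n_obs // n_obs_per_batch)            # ceil division
    pad = n_batches * n_obs_per_batch - n_obs
    padded = jnp.pad(obs_data, ((0, pad), (0, 0)))
    batched = padded.reshape(n_batches, n_obs_per_batch, -1)
    out = jax.lax.map(jax.vmap(per_obs_fun), batched)   # one chunk at a time
    return out.reshape(-1)[:n_obs]

loglike_per_obs = lambda x: -0.5 * jnp.sum(x**2)        # stand-in integrand
data = jnp.ones((10, 3))
assert jnp.allclose(map_over_obs(loglike_per_obs, data, 4),
                    jax.vmap(loglike_per_obs)(data))
```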
The previous implementation integrated the transition-period likelihood as three separate one-dimensional Halton sequences (state x shock x investment-shock) combined by outer product. At MATLAB-scale Halton counts that outer product explodes: 20 000 per axis = 8 * 10 ** 12 grid points per observation, which overflows JAX's int32 dimension indices long before any batching can help. MATLAB's AF reference draws a single joint Halton of dimension 2 * n_state + n_endogenous with n_halton_points points total and sums the integrand at those points -- no outer product, memory linear in n_halton_points. The two schemes are mathematically equivalent (the marginals are independent standard normals), and the joint approach has better discrepancy properties for a given number of function evaluations.
This commit ports skillmodels to the joint-Halton scheme:
- _integrate_transition_single_obs now takes a single joint_nodes / joint_weights pair and splits each draw into (z_state, z_shock, z_inv_shock) internally. The triple vmap is replaced by a single vmap over the joint grid.
- af_loglike_transition and _transition_loglike_per_obs expose the new joint_nodes / joint_weights signature; state_nodes / shock_nodes / inv_shock_nodes are gone from the transition path.
- transition_period.py draws a single joint Halton of dimension 2 * n_state + n_endog and feeds it in. create_shock_nodes_and_weights is no longer used there. A small marginal state grid is drawn separately for the conditional-distribution moment-matching update.
- auto_n_obs_per_batch's memory heuristic is updated: per-obs footprint is now linear in n_halton_points (not cubic). Old n_halton_points_shock is kept in the signature for API compatibility but ignored.
- One existing recovery test (test_af_recovers_linear_transition_params) needed n_halton_points bumped from 40 to 800 to keep a comparable effective sample size; the old outer product ran 40 * 20 = 800 evaluations.
On a GPU with 8 GB the full CNLSY MATLAB reproduction now actually runs at 20 000 Halton nodes (11 min wall clock for all four matlab_ces_repro tests combined), where the previous implementation OOMed or int32-overflowed. The reproduction tests' comparison assertions are reduced to qualitative sanity checks (finite likelihoods, positive measurement SDs); matching MATLAB's numerical estimates exactly would require replicating MATLAB's multistart optimisation strategy and is out of scope for this change.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
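A sketch of the joint-Halton draw, using scipy's `qmc.Halton` as a stand-in for whatever node helper skillmodels uses; dimensions and counts are illustrative.

```python
# One joint low-discrepancy draw of dimension 2 * n_state + n_endog, mapped to
# standard-normal nodes and split into the blocks the integrand needs. Memory
# is linear in n_halton_points, not cubic as with the old outer product.
import numpy as np
from scipy.stats import norm, qmc

n_state, n_endog, n_halton_points = 3, 1, 2_000
dim = 2 * n_state + n_endog

sampler = qmc.Halton(d=dim, scramble=True, seed=0)
u = sampler.random(n_halton_points)             # uniforms in (0, 1)^dim
u = np.clip(u, 1e-10, 1 - 1e-10)                # guard against ppf(0) = -inf
joint_nodes = norm.ppf(u)                       # standard-normal nodes
joint_weights = np.full(n_halton_points, 1.0 / n_halton_points)

z_state = joint_nodes[:, :n_state]
z_shock = joint_nodes[:, n_state:2 * n_state]
z_inv_shock = joint_nodes[:, 2 * n_state:]
```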
Previously ``investment`` was flagged ``is_endogenous=True``, which gave it its own initial-distribution mean and covariance block in skillmodels AF and routed it through the separate ``investment_eq`` category. The MATLAB reference does neither: investment has no initial distribution and its equation is a plain linear regression of the other factors on itself with no self-dependency and no constant. Drop the flag and use a regular ``linear`` transition instead. Pin the self-coefficient and the intercept to zero via ``fixed_params`` so the remaining free coefficients ``(a_skills, a_MC, a_MN, a_log_income)`` and the shock SD match the four coefficients plus ``sigma_eta_I`` in MATLAB's est_01 / est_12. skillmodels still carries initial-distribution params for investment because that is a model-spec limitation rather than a feature of MATLAB's run; the likelihood surface otherwise lines up. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
- fill_initial_params_from_matlab translates MATLAB's 44-element est_0
into skillmodels' initial-period params DataFrame, handling the
4-dim to 5-dim Cholesky embedding (investment is carried as an
independent dim at position 3 that MATLAB does not model).
- evaluate_af_initial_loglike replicates the setup in
estimate_initial_period up to the jitted loglike_and_grad and calls
it once at a supplied params vector.
- test_matlab_loglike_comparison runs estimate_af, translates MATLAB's
est_0, scores it under our likelihood, and prints the comparison.
Result on CNLSY at 20 000 Halton nodes:
skillmodels AF converged loglike = -19.112239
skillmodels likelihood at MATLAB est_0 = -19.369483
difference = +0.257245 (skillmodels higher)
Our own optimum scores ~0.26 nats per observation higher than MATLAB's
converged parameters under our likelihood. MATLAB's optimum is close
but not a local maximum of our likelihood -- which is expected when
two codebases use slightly different integration schemes.
Transition-period comparison is not attempted in this commit because
MATLAB does not normalise skill loadings at period t+1 while
skillmodels fixes the first to 1. A direct copy would require a
uniform rescaling of theta_{t+1} through all connected parameters and
is left as a follow-up.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Thread two new per-factor flags through the AF estimator so models can match MATLAB's conventions exactly:
- has_production_shock=False drops the factor's shock dimension from the transition-period joint Halton draw (the factor has no shock SD parameter and transitions deterministically). Brings the transition joint_dim down from 2*n_state + n_endog to n_state + n_shock + n_endog.
- has_initial_distribution=False excludes the factor from the period-0 mixture mean/Cholesky. Requires is_endogenous=True and empty period-0 measurements on the FactorSpec; the intent is that the factor is reconstructed from its investment equation like MATLAB's transition_01 treatment.
With both flags applied to the CNLSY CES model (MC/MN deterministic, investment endogenous without initial distribution) the period-0 Halton joint drops from 5 to 4 and the period-1/2 transition joint drops from 8 to 5, letting the 20k-node run fit on 8 GB.
Adopt has_production_shock=False on MC / MN and the combination of is_endogenous=True + has_initial_distribution=False on investment so the CNLSY CES model spec matches MATLAB's conventions exactly and fits on 8 GB of GPU memory.
Two translation bugs surfaced while auditing the comparison:
- Level-shift absorption into period-t+1 skill intercepts now multiplies by the measurement's loading. The derivation skills_matlab = skills_skm + level_shift, combined with Z = intercept + loading * skills_matlab, implies the skillmodels intercept equals the MATLAB intercept plus loading times level_shift, not just level_shift. Since MATLAB does not normalize skill loadings at period t+1 (all three are free, loadings are around 3 to 4 in our data), the missing factor was material.
- Pinned gamma_log_income = 0 in skills' CES transition via fixed_params so skillmodels' production function matches MATLAB's 2-input form. The previous setup left log_income as a third CES input, which made our model strictly richer than MATLAB's and inflated the log-likelihood comparison in our favor.
The same alignment is applied to the translog variant. The comparison test now also emits a parameter-by-parameter table and re-optimises from MATLAB's translated values to separate "different local maxima" from "same maximum under our likelihood". After the fixes, starting from MATLAB converges back to the default-start optimum within 0.0004 nats, so the residual 2.48-nat gap (concentrated at period 2) is one basin, not two.
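The loading factor in the first fix follows directly from the two relations quoted above; written out (with lambda the free period-t+1 loading and s the level shift):

```latex
% theta^M = MATLAB skills, theta^S = skillmodels skills, s = level shift,
% lambda = measurement loading, a = measurement intercept.
\theta^{M}_{t+1} = \theta^{S}_{t+1} + s
\qquad
Z = a^{M} + \lambda\,\theta^{M}_{t+1}
  = a^{M} + \lambda\,(\theta^{S}_{t+1} + s)
  = \underbrace{\bigl(a^{M} + \lambda s\bigr)}_{a^{S}} + \lambda\,\theta^{S}_{t+1}
```

So the skillmodels intercept is the MATLAB intercept plus loading times level shift, which is material when the free loadings are around 3 to 4.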
Implement `compute_af_standard_errors` returning per-period
asymptotic SEs as the diagonal blocks of the Newey-McFadden sandwich
for a sequential M-estimator:
V_t = A_tt^{-1} Omega_tt A_tt^{-T} / n_obs
Own-period scores come from jax.jacfwd of the per-obs log-likelihood;
the information matrix A_tt is jax.hessian of the negative mean
log-likelihood. Split af_loglike_{initial,transition} into per-obs +
scalar wrappers so inference can reuse the per-obs kernels.
Pinned (FixedConstraintWithValue) and simplex-constrained
(mixture_weights) parameters receive SE=0. Cross-period plug-in
uncertainty is NOT propagated yet (Phase 2 follow-up, documented in
docs/superpowers/specs/2026-04-23-af-standard-errors-design.md).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
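A toy sketch of the Phase-1 mechanics (per-observation scores via `jax.jacfwd`, information matrix via `jax.hessian`, then V = A^{-1} Omega A^{-T} / n_obs); the per-observation log-likelihood here is a plain Gaussian stand-in, not the AF likelihood.

```python
# Toy per-period sandwich: only the mechanics mirror the commit above.
import jax
import jax.numpy as jnp

def loglike_per_obs(params, x):
    mu, log_sigma = params
    sigma = jnp.exp(log_sigma)
    return -0.5 * ((x - mu) / sigma) ** 2 - log_sigma - 0.5 * jnp.log(2 * jnp.pi)

key = jax.random.PRNGKey(0)
data = 1.5 + 0.7 * jax.random.normal(key, (500,))
params = jnp.array([1.4, jnp.log(0.8)])           # pretend these are the MLE

scores = jax.vmap(jax.jacfwd(loglike_per_obs), in_axes=(None, 0))(params, data)
omega = scores.T @ scores / data.shape[0]          # outer product of scores
neg_mean_loglike = lambda p: -jnp.mean(jax.vmap(loglike_per_obs, (None, 0))(p, data))
a_matrix = jax.hessian(neg_mean_loglike)(params)   # information matrix A_tt

a_inv = jnp.linalg.inv(a_matrix)
vcov = a_inv @ omega @ a_inv.T / data.shape[0]
standard_errors = jnp.sqrt(jnp.diag(vcov))
print(standard_errors)
```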
Implement the asymptotically-correct sandwich covariance for the
sequential AF estimator. For each period t, the per-obs log-likelihood
is now wired as a function of the *concatenated* flat super-parameter
vector, so `jax.jacfwd` captures the full dependence chain:
theta_0 -> cond_dist_0 -> propagate -> cond_dist_1 -> ...
Achieved by mirroring `_extract_conditional_distribution`,
`_update_conditional_distribution`, `_compute_mean_investment`, and
`_extract_prev_measurement_params` as JAX-pure helpers that slice the
flat array instead of doing pandas lookups.
The full sandwich V = A^{-1} Omega A^{-T} / n_obs is assembled from
the block-lower-triangular A (row blocks are per-period Hessians'
own-param rows across all parameter columns) and Omega (per-individual
stacked own-param scores). Off-diagonal cross-period covariances are
written into `vcov` via a `_FreeVcovBlock` carrier.
`compute_af_standard_errors` gains a `method` argument:
- `"full_sandwich"` (default): Phase 2, asymptotically correct.
- `"block_diagonal"`: Phase 1, conservative per-period blocks.
Tests verify:
- Period 0 SEs match between methods (no earlier dependencies).
- Period 2's full-sandwich SE >= block-diagonal SE (plug-in uncertainty).
- Cross-period covariance block is non-zero in full sandwich.
- Unknown `method` raises ValueError.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
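The block structure can be sketched with placeholder inputs: A is block lower triangular (later periods carry derivative blocks with respect to earlier-period parameters), Omega stacks per-individual own-period scores, and the full sandwich follows. All numbers below are made-up illustrations, not estimates.

```python
# Placeholder assembly of V = A^{-1} Omega A^{-T} / n_obs from per-period blocks.
import numpy as np

rng = np.random.default_rng(0)
n_obs, dims = 400, (3, 2)                  # two periods, 3 and 2 free params
k = sum(dims)

scores = rng.normal(size=(n_obs, k))       # stacked own-period scores per individual
omega = scores.T @ scores / n_obs

a = np.zeros((k, k))                       # block lower-triangular A
a[:3, :3] = 2.0 * np.eye(3)                # period-0 own-param Hessian block
a[3:, :3] = 0.3 * rng.normal(size=(2, 3))  # period-1 rows w.r.t. period-0 params
a[3:, 3:] = 1.5 * np.eye(2)                # period-1 own-param Hessian block

a_inv = np.linalg.inv(a)
vcov = a_inv @ omega @ a_inv.T / n_obs     # full sandwich
se_full = np.sqrt(np.diag(vcov))

# Period-0 block coincides with the block-diagonal method, as the tests assert:
v00 = np.linalg.inv(a[:3, :3]) @ omega[:3, :3] @ np.linalg.inv(a[:3, :3]).T / n_obs
assert np.allclose(vcov[:3, :3], v00)
```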
Code review
No issues found. Checked for bugs and CLAUDE.md compliance in the two standard-error commits.
🤖 Generated with Claude Code
Code review (full, including low-confidence items)
Below is the full list of issues surfaced across five review agents on the Phase 1 + Phase 2 standard-error commits.
Real potential issue (85) — shock_sds shape mismatch for models with … The JAX-pure propagator does …
  skillmodels/src/skillmodels/af/inference.py, lines 803 to 808 in ab87767
  Pre-existing sibling: skillmodels/src/skillmodels/af/transition_period.py, lines 730 to 734 in ab87767
CLAUDE.md: … AGENTS.md says: "Suppress errors with …"
CLAUDE.md: internal dataclass uses … The repo CLAUDE.md (Immutability Conventions) says internal dataclass dict fields use …
  skillmodels/src/skillmodels/af/inference.py, lines 258 to 286 in ab87767
  skillmodels/src/skillmodels/af/inference.py, lines 293 to 299 in ab87767
CLAUDE.md: multiple assertions per test (unscored) — AGENTS.md says "One assertion per test". Several tests pack 2-4 independent assertions, e.g.:
  skillmodels/tests/test_af_inference.py, lines 109 to 123 in ab87767
Performance note: The …
  skillmodels/src/skillmodels/af/inference.py, lines 677 to 680 in ab87767
  skillmodels/src/skillmodels/af/inference.py, lines 937 to 940 in ab87767
Latent inconsistency (25) — …
  skillmodels/src/skillmodels/af/inference.py, lines 868 to 875 in ab87767
Flagged but confirmed false positives (0): …
🤖 Generated with Claude Code
`_to_numpy_conditional_distribution` cleared `samples_per_component` only at the end, inside the same `dataclasses.replace` call that also wrote the converted summary arrays. By the time `_to_numpy(c.mean)` ran the multi-GB importance-sample arrays were still live on the device, so even a tiny `(n_state,)` materialisation hit a 335 MiB staging allocation failure on busy GPUs. Replace each `ConditionalDistribution` in the list with a `samples_per_component=()` copy first, drop the loop variable, and run `gc.collect()` + `jax.clear_caches()` -- that frees the giant buffers before any host copy runs. Conversion of the remaining (small) summary arrays then succeeds even when the rest of the GPU is full of the just-finished optimisation's intermediates. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
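A sketch of the "drop the big device arrays before any host copy" pattern, with a simplified stand-in for `ConditionalDistribution`; the field names follow the commit text, the real dataclass differs.

```python
# Pass 1 clears the huge importance-sample arrays everywhere so their device
# buffers lose their last reference; only afterwards are the small summary
# arrays materialised on the host.
import dataclasses
import gc
import jax

def to_numpy_conditional_distributions(dists):
    dists = [dataclasses.replace(d, samples_per_component=()) for d in dists]
    gc.collect()
    jax.clear_caches()   # free the just-finished optimisation's intermediates
    return [
        dataclasses.replace(
            d,
            mean=jax.device_get(d.mean),
            chol_cov=jax.device_get(d.chol_cov),
        )
        for d in dists
    ]
```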
`skillmodels.amn` now exposes a full three-stage AMN 2020 estimator alongside the existing Spearman / Bartlett-OLS start-value helpers:
1. `mixture_em.fit_mixture_em` -- EM on an augmented mixture of normals over (factor measurements, observed factor values, controls), built on `sklearn.mixture.GaussianMixture`. Listwise complete-case for v0.
2. `minimum_distance.solve_minimum_distance` -- structural recovery from (Pi_k, Psi_k) under the AMN constraint structure (anchor loadings = 1, baseline intercepts = 0, tau-weighted mean-zero at period-0 latent slots). Mirrors `STEP2_func.R` from the AMN 2020 supplementary archive.
3. `simulate_and_regress.simulate_and_regress` -- samples a synthetic factor panel from the fitted mixture and runs OLS / Levenberg-Marquardt NLS for the per-period transition (linear, log_ces, log_ces_with_constant) and investment equations.
`estimate.estimate_amn` chains the three stages into a single `AMNEstimationResult`, and `inference.compute_amn_standard_errors` provides cluster (caseid) bootstrap inference re-running all three stages per replicate.
Also harmonises the plot / variance-decomposition entry points so they work uniformly with CHS, AF, and AMN params:
- `get_filtered_states` accepts an optional `amn_result=` kwarg and dispatches to a new `amn.posterior_states.get_amn_posterior_states` (mixture-Schur conditional E[theta | Y_i]).
- `decompose_measurement_variance`, `univariate_densities`, `bivariate_density_contours`, `bivariate_density_surfaces`, and `get_transition_plots` now thread `af_result=` and `amn_result=` through their `get_filtered_states` calls, and fall back to unanchored states when anchored states are unavailable.
Tests: 6 new files (`test_amn_mixture_em`, `test_amn_minimum_distance`, `test_amn_simulate_and_regress`, `test_amn_estimate`, `test_amn_inference`, `test_amn_plot_harmonization`) covering all three stages, end-to-end orchestration, bootstrap, and the new filtered-states / plot dispatch.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
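A minimal Stage-1 sketch: fit the augmented data matrix with `sklearn.mixture.GaussianMixture` and read off per-component moments. The column construction and missing-data handling in `fit_mixture_em` are more involved, and the data below is a random stand-in.

```python
# Stage-1 EM on the stacked (measurements, observed factors, controls) matrix.
import numpy as np
from sklearn.mixture import GaussianMixture

rng = np.random.default_rng(0)
wide = rng.normal(size=(1_000, 6))                # stand-in augmented data matrix
wide = wide[~np.isnan(wide).any(axis=1)]          # listwise complete cases

gm = GaussianMixture(n_components=2, covariance_type="full", random_state=0)
gm.fit(wide)

# Stage 2 consumes the per-component moments (Pi_k, Psi_k in AMN's notation).
means, covariances, weights = gm.means_, gm.covariances_, gm.weights_
```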
- `EstimationOptions.start_params_strategy` default: `"moment_based"` → `"amn"`. Renames the legacy Spearman / Bartlett-OLS hybrid value from `"moment_based"` to the more descriptive `"spearman"`. Accepted values are now `Literal["none", "spearman", "amn"]`.
- `AFEstimationOptions.initialization_strategy` default: `"moment_based"` → `"amn"`. Same rename; accepted values are `Literal["constant", "spearman", "amn"]`.
- `get_moment_based_start_params` renamed to `get_spearman_start_params`.
When `"amn"` is selected:
- `chs.get_maximization_inputs` runs `estimate_amn` on the dataset and overlays its parameter estimates onto the template, falling back to Spearman seeds for entries AMN doesn't touch (mixture weights, initial Cholesky diagonals).
- `estimate_af` runs `estimate_amn` once upfront, merges the result with any user-supplied `start_params` (user values win on overlap), and switches the per-period MLE to the `"constant"` defaults so the within-period Spearman pre-pass is skipped (AMN's values are already in the optimizer's neighbourhood).
Performance note: running the full AMN three-stage estimator is non-trivial on small datasets (a few seconds even for a 2-period skillmodels test model). Test fixtures `MODEL2` and `SIMPLEST_AUGMENTED_MODEL` therefore opt into `start_params_strategy="spearman"` explicitly so the CHS / AF test plumbing stays fast; the public `EstimationOptions()` default remains `"amn"`.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`skillmodels.amn.mixture_em.fit_mixture_em` uses `sklearn.mixture.GaussianMixture` as its Stage 1 engine, and the `amn` package's `__init__.py` re-exports `estimate_amn` (which transitively imports `mixture_em`). The CI tests-cpu environment was missing scikit-learn, so collection failed on all three runners (macOS / Windows / Linux). Adds scikit-learn to both PyPI and Pixi dependency tables; the regenerated lock pulls scikit-learn 1.8.0 on all supported platforms. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`scikit-learn` is now a hard `skillmodels` dependency (used by `amn.mixture_em.fit_mixture_em`). Mirrors the addition in `pyproject.toml` across the three deployment artefacts -- CPU conda env, CUDA-12 conda env, and pip-only requirements -- so CBS deployments that bootstrap from these files don't hit `ModuleNotFoundError` at import time. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
AMN Stage 3 (`simulate_and_regress`) only supports linear, log_ces, and log_ces_with_constant transitions. When a model uses translog or a `@register_params`-decorated user transition function, `estimate_amn` raises `NotImplementedError`. With AMN as the default start-value strategy, this turned previously-passing CHS / AF tests (`test_af_estimate_with_translog`, `test_af_estimate_with_register_params_user_transition`, `test_af_joint_halton_recovers_sigma_prod_with_chain_link`) into regressions. Both `estimate_af` and `chs.get_maximization_inputs` now catch `NotImplementedError` from `estimate_amn`, emit a RuntimeWarning, and fall back to the cheap per-period Spearman seeds. AF additionally swaps `initialization_strategy="amn"` for `"spearman"` so the per-period MLE still benefits from data-driven starts. Also drops a `# ty: ignore[unresolved-import]` on `from sklearn.mixture import GaussianMixture` now that scikit-learn is a declared dependency. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replaces the previous `NotImplementedError`-then-fall-back hack with proper support. Specialised fitters stay for the cases where they pay off: closed-form OLS for `linear` and softmax-constrained Levenberg-Marquardt for `log_ces` / `log_ces_with_constant` (keeps gammas on the simplex). Everything else -- `translog`, `robust_translog`, `linear_and_squares`, `log_ces_general`, and any user `@register_params`-decorated transition -- now flows through a generic NLS path that calls the transition function directly via `jax.vmap`.
Concretely:
- `_resolve_transition_callable` looks up the built-in function from `skillmodels.common.transition_functions` for known names, or wraps the user's raw callable via a new `_make_user_transition_callable` helper (a Stage-3 mirror of AF's `_wrap_registered_transition_function`).
- `_fit_generic_nls` jit-compiles a vmapped predictor, then runs `scipy.optimize.least_squares` with sensible defaults (phi/rho seeded at 0.5, CES-shaped functions get uniform-share gammas).
- `simulate_and_regress` now takes `model_spec` so the user-callable lookup has access to `model_spec.factors[f].transition_function`.
Removes the temporary `try/except NotImplementedError -> Spearman fallback` in `estimate_af` and `chs.get_maximization_inputs`: AMN now handles every transition, so the fallback is dead code.
Test coverage: a new `test_simulate_and_regress_handles_translog` exercises the generic NLS path; the previously-regressing `test_af_estimate_with_translog`, `test_af_estimate_with_register_params_user_transition`, and `test_af_joint_halton_recovers_sigma_prod_with_chain_link` all pass without the fallback.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
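A sketch of the generic-NLS path: jit a vmapped transition predictor and minimise its residuals with `scipy.optimize.least_squares`. The transition below is a linear stand-in (squares and interaction terms omitted) and all names and data are illustrative.

```python
# Generic NLS over a vmapped transition function, illustrative only.
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import least_squares

def user_transition(states, params):
    # Linear stand-in for a registered transition function.
    return params[0] + jnp.dot(params[1:], states)

predict = jax.jit(jax.vmap(user_transition, in_axes=(0, None)))

rng = np.random.default_rng(0)
states = rng.normal(size=(500, 2))
true_params = np.array([0.2, 0.6, 0.3])
next_states = np.asarray(predict(states, true_params)) + 0.05 * rng.normal(size=500)

def residuals(params):
    return np.asarray(predict(states, params)) - next_states

fit = least_squares(residuals, x0=np.full(3, 0.5))
print(fit.x)   # close to true_params
```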
Same-step equality groups (whose members all live in one AF transition
step) are now forwarded verbatim to that step's `om.minimize`, in
addition to the existing cross-period forward-propagation that pins
later-period members to earlier-period estimates via `fixed_params`.
Adds `filter_within_step_constraints` and `reconcile_start_to_equality`
helpers in `common/constraints.py`; the latter averages each group's
current values before optimization so `om.minimize` doesn't reject the
start point with `InvalidParamsError`.
Needed for `sigma_inv_t == sigma_meas_inv_1_{t+1}` and similar same-step
identification constraints, which `_propagate_equality_groups` cannot
enforce (no anchor estimate exists when both members are in the same
step).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`ConditionalDistribution.samples_per_component` is kept only for the posterior-state summary stats (`MixtureComponent.mean`, `chol_cov`) that downstream `posterior_states` / inference consume; the transition likelihood rebuilds the chain on-demand from `chain_links` (see `_rebuild_chain_at_period`). The docstring already noted it "may use a smaller Halton count than the likelihood's `n_halton_points`" but the option wasn't wired through. Add `AFEstimationOptions.n_halton_points_posterior_summary` (default 256). `_extract_conditional_distribution` slices `nodes` to that count for `samples_per_component` construction; `_chain_one_component` propagates the resulting smaller `prev_sample` correctly (loop bound is `prev_sample.shape[0]`, not the full joint-Halton dim). Effect: at N=50k, T=5, n_halton=10k the persistent chain-replay tensor shrinks from 5*10000*50000*3*8 ~= 60 GB to 5*256*50000*3*8 ~= 1.5 GB. Likelihood values are unchanged (the path that consumes the full halton count is `_rebuild_chain_at_period`, which doesn't touch `samples_per_component`). Tests: 27 AF estimate + equality propagation tests pass; ty clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Follow-up cleanup to f612949 and 15ef656 from code-simplifier:
* `common/constraints.py`: extract `_equality_constraint_loc(c)` so `filter_within_step_constraints` and `reconcile_start_to_equality` share one definition of what a `select_by_loc`-style `EqualityConstraint` looks like. Trims 5 lines of nested guards from each caller.
* `af/initial_period.py`: replace `min(n_full, n) if n else n_full` with the `is None`-checked form so a legitimate `0` for `n_summary_halton` would not silently mean "use full halton".
* `af/transition_period.py`: compress redundant comment + one-line the shape unpack inside `_chain_one_component`.
Tests: 27 AF estimate + equality propagation tests pass; ty clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Two follow-ups to f612949 / 15ef656 / af42b3f, both surfaced by a multi-agent review of PR #89:
* `af/estimate.py`: the AMN-rebuild path of `AFEstimationOptions` (used when `initialization_strategy="amn"`) enumerated every field explicitly and forgot the two newly-added ones. A caller passing `keep_conditional_distributions=False` together with the default AMN initialization had their flag silently reverted to True, re-introducing the device->host OOM the flag was added to avoid. Add both `keep_conditional_distributions` and `n_halton_points_posterior_summary` to the reconstruction.
* `af/types.py`: validate that `n_halton_points_posterior_summary >= 1` in `AFEstimationOptions.__init__`. The previous code path (`min(n_full, n)` when `n is not None`) accepted 0, which would produce `samples_per_component` tensors of shape `(0, n_obs, n_state)` and NaN `MixtureComponent.mean`. AFEstimationOptions is a user-facing boundary, so raising here matches the project's "validate at boundaries" stance.
Tests: 27 AF estimate + equality-propagation tests pass; ty clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…ansition out of chs/
Three estimator-agnostic utilities were physically located under skillmodels.chs.* but were already being imported from outside CHS (AF + AMN posterior_states; common/visualize_transition_equations; common/simulate_data), which made common/ depend on chs/ and muddled the package boundary.
* create_state_ranges -> new skillmodels.common.state_ranges. Pure DataFrame utility. AF/AMN posterior_states + the test suite now import from common; CHS-internal callers updated. Dropped from skillmodels.chs.__init__ (no longer CHS-specific). Top-level re-export preserved for now (the wider __init__ cleanup is a separate follow-up).
* anchor_states_df -> new skillmodels.common.anchoring. Operates on a (obs x period x factor) DataFrame via the per-period (scale, offset) pair from ModelSpec.anchoring; no CHS algorithms involved. The CHS get_filtered_states caller and simulate_data now import from common.
* apply_anchored_transition -> new skillmodels.common.transitions. Extracted from CHS's transform_sigma_points: the anchor -> transition -> unanchor pipeline that all three estimators need. The CHS UKF retains a thin sigma-points-shape wrapper that flattens (n_obs, n_mixtures, n_sigma, n_fac) to (N, n_fac) for the shared core. simulate_dataset now calls the common helper directly, dropping the unnecessary (1, 1, n_obs, n_fac) sigma-points reshape it used to do.
Tests: 185 passed across the 5 affected test files; ty clean.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`plot_residual_boxplots`, `plot_likelihood_contributions`, and `decompose_measurement_variance` previously hid a CHS-specific `get_maximization_inputs` / `get_filtered_states` call behind the `data` argument, which made them silently CHS-only. Drop the `data` parameter and require the caller to pass a pre-computed DataFrame (`residuals`, `contributions`, `filtered_states`). This decouples the common/ functions from CHS so AF and AMN callers can use them unchanged, and matches the cross-estimator dispatch pattern. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Make CHS-specific tuning parameters explicit. `EstimationOptions` exposed fields that are conceptually CHS-only (n_mixtures, sigma_points_scale, clipping_*), so callers were silently assuming CHS even when working with AF or AMN. Rename the class and the corresponding ModelSpec/ProcessedModel attribute and builder (`with_estimation_options` → `with_chs_estimation_options`) so the CHS scope is visible at every call site. Re-export from skillmodels.chs alongside the other CHS entry points. The class still lives in skillmodels.common.types because process_model() reads `bounds_distance` to build common endogenous-factors-info; a future split into truly-common fields plus CHS-specific extension can move it physically. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The top-level `skillmodels` namespace previously re-exported every public entry point across CHS, AF, AMN, and common helpers. This hid which estimator each function belonged to (e.g. `get_maximization_inputs` is CHS-only, `estimate_amn` is AMN-only) and conflicted with the new common/chs/af/amn subpackage split. Restrict the top-level imports to the four estimator-agnostic model_spec building blocks (`ModelSpec`, `FactorSpec`, `AnchoringSpec`, `Normalizations`). Everything else moves to its subpackage:
* `skillmodels.chs` — `get_maximization_inputs`, `get_filtered_states`, `CHSEstimationOptions`
* `skillmodels.af` / `skillmodels.amn` — already re-exported
* `skillmodels.common.diagnostic_plots` — plot helpers
* `skillmodels.common.variance_decomposition` — variance helpers
* `skillmodels.common.simulate_data` — simulation helpers
* `skillmodels.common.state_ranges` — `create_state_ranges`
Update internal tests, docs notebooks, and the model-specs how-to to use the new paths.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Three coupled changes that finish the API/architecture cleanup:
1. Lift `n_mixtures` to `ModelSpec.n_mixtures: int = 1`. It is a structural property of the latent-mixture model (used by both CHS Kalman and AMN moment-init via Dimensions), not an estimation-tuning knob.
2. Physically move `CHSEstimationOptions` from common/types.py to the new chs/options.py. The class never depended on common types; only common code's casual use of `bounds_distance` kept it there for layering reasons.
3. Drop `chs_estimation_options` from `ModelSpec` and `ProcessedModel`. CHS callers now pass `chs_options` as a keyword argument to `get_maximization_inputs(...)`. This matches `estimate_af(..., af_options=...)` and `estimate_amn(..., amn_options=...)` -- the three estimators are now symmetric in how they take their tuning parameters.
`get_constraints` and `_get_constraints_for_augmented_periods` gain an explicit `bounds_distance` parameter; the field is removed from `EndogenousFactorsInfo`. Test fixtures move CHSEstimationOptions out of MODEL2/SIMPLEST_AUGMENTED_MODEL into sibling `*_CHS_OPTIONS` constants; tests that exercise CHS thread them through.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Add `AFEstimationOptions.optimizer_backend: Literal["optimagic", "jaxopt"] = "optimagic"`. When set to "jaxopt", each period's MLE runs `jaxopt.LBFGSB` on the on-device parameter vector instead of crossing host<->device once per iteration through `optimagic`.
The jaxopt path supports `FixedConstraintWithValue` (pinned values from normalisations + user `fixed_params`) plus parameter bounds. Probability and equality constraints are out of scope -- the likelihood does not see them on-device because optimagic's constraint folding happens above the jit boundary. Models with those (log_ces transitions, cross-section equalities) keep using the optimagic backend; the jaxopt wrapper raises NotImplementedError with a clear hint.
`optimizer_options` is forwarded directly to `LBFGSB(**...)` for the jaxopt backend; relevant keys are `maxiter`, `tol`, `history_size`. The `optimizer_algorithm` field is ignored when the jaxopt backend is selected (jaxopt always uses L-BFGS-B).
Tests:
* Unit tests for `minimize_with_jaxopt` (smoke quadratic, pinned values, unsupported-constraint rejection).
* End-to-end parity test: linear single-factor AF estimation with `optimizer_backend="optimagic"` vs `"jaxopt"` produces matching log-likelihoods and free loadings.
* Negative test: log_ces model with `optimizer_backend="jaxopt"` raises NotImplementedError.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
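A toy sketch of the jaxopt path: `LBFGSB` run on the device parameter vector with box bounds. The objective is a stand-in quadratic; how skillmodels wires in the AF likelihood and applies `FixedConstraintWithValue` internally may differ.

```python
# Bounded L-BFGS on device with jaxopt; the bound on the third entry stands in
# for a parameter bound (the real wrapper also pins fixed entries).
import jax.numpy as jnp
from jaxopt import LBFGSB

def neg_loglike(x):
    return jnp.sum((x - jnp.array([1.0, 2.0, 3.0])) ** 2)

lower = jnp.array([-jnp.inf, -jnp.inf, 0.0])
upper = jnp.array([jnp.inf, jnp.inf, 2.5])

solver = LBFGSB(fun=neg_loglike, maxiter=200, tol=1e-8)
result = solver.run(jnp.zeros(3), bounds=(lower, upper))
print(result.params)    # approximately [1.0, 2.0, 2.5]
```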
Refresh the documentation to reflect what landed on the af-estimator branch:
* `index.md` — replace the CHS-only public-API narrative with the three-estimator overview (chs/af/amn subpackages), list the top-level model_spec re-exports, and call out the optional jaxopt backend for AF.
* New `explanations/architecture.md` — map the common/chs/af/amn subpackage split, describe how each estimator reuses `process_model` and the canonical params index, and state the layering rule (common/ never imports from chs/af/amn).
* New `how_to_guides/how_to_estimate_af.md` — minimal example using `estimate_af`, the `optimizer_backend` choice (optimagic vs jaxopt) with a decision table, the `initialization_strategy` knob, and the current anchoring / endogenous-factor support.
* `model_specs.md` — drop the stale `chs_estimation_options=CHSEstimationOptions()` kwarg from the `ModelSpec(...)` literal (it no longer exists; CHS options are passed at call time). Point to the AF how-to for the matching call-site pattern.
* `tutorial.ipynb` — update the MODEL2 import to also pull MODEL2_CHS_OPTIONS and pass it to `get_maximization_inputs`, matching the test fixtures.
* `myst.yml` — add the new pages to the toc.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
`AFEstimationOptions.optimizer_backend` gains a new literal `"auto"`, which is now the default. The resolution happens inside `estimate_af` via the new private `_resolve_optimizer_backend` helper:
* If a JAX GPU is visible (`any(d.platform == "gpu" for d in jax.devices())`) AND the model is jaxopt-compatible (no `log_ces*` transitions -- they introduce probability constraints jaxopt can't fold -- and no caller-supplied `constraints`, which would arrive as equality constraints), use `"jaxopt"`.
* Otherwise fall back to `"optimagic"`.
Explicit `"optimagic"` / `"jaxopt"` requests are honoured as-is.
Also rolls in the visualize_* Camp 2 refactor, the CNLSY CSV vendoring, and the pandas-PerformanceWarning suppression in pytest filterwarnings.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
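The resolution rule can be sketched as a small helper; the name mirrors `_resolve_optimizer_backend`, but the exact compatibility checks in skillmodels may differ.

```python
# Sketch of the "auto" backend resolution described above.
import jax

def resolve_optimizer_backend(requested, has_log_ces, has_user_constraints):
    if requested != "auto":
        return requested                     # explicit requests are honoured
    gpu_available = any(d.platform == "gpu" for d in jax.devices())
    jaxopt_compatible = not has_log_ces and not has_user_constraints
    return "jaxopt" if gpu_available and jaxopt_compatible else "optimagic"
```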
The tutorial exercises CHS + AF + AMN on the same `fixed_params` and
flushed out a cluster of latent bugs.
1. `select_by_loc(params, single_tuple)` returned a row Series indexed
by column names. When optimagic's pytree machinery flattened it,
all three of (`value`, `lower_bound`, `upper_bound`) were cast to
`int` -- `±inf` collapsed to the int64 sentinel, producing a
duplicate that indexed off the end of `param_names` and raised
`IndexError: list index out of range` from
`optimagic.parameters.process_selectors._fail_if_duplicates`. Both
`common/constraints.select_by_loc` and the deliberate near-copy in
`common/transition_functions.select_by_loc` now project the result
down to the `value` column before returning.
2. CHS's `get_maximization_inputs` now calls a new
`_project_to_probability_constraints` after the AMN / Spearman
seeding step: every `ProbabilityConstraint` whose entries don't
sum to one gets its free members rescaled to fill
`1 - sum(pinned values)`. Pinned entries stay pinned. Without this,
AMN-seeded gammas were rejected by `check_constraints_are_satisfied`
("Probabilities do not sum to 1") and Spearman-seeded uniform
weights gave the wrong target.
3. AMN's `_apply_overrides` was using `MultiIndex.union` to merge
`fixed_params` / `start_params` into the combined `all_params`.
`union` silently drops every level whose name differs across the
two operands -- and user overrides come in keyed by the public
`period` level while AMN's combined frame uses `aug_period`. The
resulting `all_params` ended up with `None`-named levels, which
then broke `params.loc[...]` callers like
`decompose_measurement_variance`. New `_align_index_names` helper
re-stamps the override's level names to match the target before
the union so the tuples stay identical and the names survive.
4. `decompose_measurement_variance` was hard-coded to read
`aug_period` out of `params.loc["loadings"].reset_index()`, which
only works for CHS params. AF / AMN params expose `period`. The
`rename` block now accepts either spelling.
The pre-existing `identity_constraints_log_ces*` signature
mismatch (positional `factors` vs CHS's `(factor, aug_period,
all_factors)` dispatch) is consolidated into the no-op form here
together with the matching test rewrites.
Pre-commit-config picks up the `nbstripout` exclude needed by the
follow-on tutorial-render commit.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop the cross-factor gamma pins from the skills CES (all five gammas now free under the simplex constraint). With the projection step landed in the previous commit, AMN seeding feeds the optimizer a feasible start; CHS converges (log-likelihood -39620.6) and the optimizer drives non-skills gammas toward zero where the data allows (period-0 MC stays ~0.11, period-1 investment ~0.10). Switch the AF cell to `optimizer_algorithm="scipy_lbfgsb"` so the notebook runs in environments without `fides` (the default fides algorithm wasn't part of the pixi env). All 14 code cells execute; the notebook ships pre-rendered. `pyproject.toml` per-file-ignores for `**/*.ipynb` now waive `ANN` and `PD010` so tutorial helpers (which take `params`, `period`, `meas` positionally and pivot small presentation tables) don't need full annotations or `pivot_table` conversion. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…t safety
Addresses every in-scope finding from the post-push review of 1de00a6..074cc83, plus two pre-existing bugs the review surfaced and one jaxopt-incompatibility the skane-struct-bw pytask run hit.
Refactor
--------
* New `common/selector.py` hosts `select_by_loc` and `align_index_names` as a single dependency for both `common/constraints.py` and `common/transition_functions.py`. The previous near-copy in `transition_functions.py` (kept to dodge a circular import) becomes a plain re-export and the lazy `import pandas as pd` workaround goes away.
* `_project_to_probability_constraints` and `_collect_fixed_locs` move out of `chs/maximization_inputs.py` into `common/constraints.py` as the public `project_to_probability_constraints` / `collect_fixed_locs`. They are general constraint-reconciliation primitives, not CHS-specific.
* `amn/estimate.py` drops its private `_align_index_names` in favour of the shared `common.selector.align_index_names`.
CHS / AMN symmetry
------------------
* `chs/maximization_inputs._build_fixed_constraints_from_params` now normalises the user override's level names before `params_index.intersection(...)`. Previously a `fixed_params` frame keyed by the public `period` level silently produced an empty intersection (`MultiIndex.intersection` returns nothing when level names diverge) and user fixes vanished without a warning.
* `collect_fixed_locs` learns about the `pd.MultiIndex` arm of `FixedConstraintWithValue.loc`.
AF jaxopt safety
----------------
* `af.estimate._filter_step_constraints` strips non-`FixedConstraintWithValue` entries from the per-step constraint list when `optimizer_backend="jaxopt"` and emits a single `RuntimeWarning`. Without this, any user-supplied `EqualityConstraint` reached `_check_constraints_supported` and triggered `NotImplementedError`. Cross-period equalities are still propagated via `_extract_equality_groups` / `_propagate_equality_groups`; only within-step equalities are lost, and the warning surfaces that.
Pre-existing bugs surfaced by the review
----------------------------------------
* `common/process_model._augment_periods_for_endogenous_factors` reconstructs the augmented `FactorSpec` and was dropping `has_production_shock` and `has_initial_distribution`, both of which default to `True`. A model that set either flag to `False` silently saw it flip back to `True` whenever endogenous-period augmentation ran. Now forwards both fields, with a regression test.
* `af/transition_period.py:112` carried a stale "For now, use the first non-constant factor's transition for the combined function" comment that describes neither what the surrounding code does nor why; removed.
Tests
-----
New `tests/test_selector.py` covers the four small primitives in isolation. Plus three regression tests on the public APIs:
* `test_estimate_amn_honors_fixed_params_keyed_by_period`
* `test_get_maximization_inputs_accepts_fixed_params_keyed_by_period`
* `test_compute_variance_decomposition_with_period_level_params`
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
… concepts
New how-tos
-----------
* `how_to_estimate_amn.md` -- minimal example, a synthetic 2-mixture DGP that's the smallest case where AMN's non-Gaussian latent fit beats CHS, tuning knobs by stage, inference, and a "what AMN does not (yet) do" punch list.
* `how_to_compare_estimators.md` -- picks up where `tutorial.ipynb` leaves off and quantifies uncertainty across all three estimators: CHS analytic sandwich (`estimate_ml`), AF score bootstrap (`compute_af_standard_errors`), AMN cluster bootstrap (`compute_amn_standard_errors`). Includes a posterior-trajectory overlay across estimators and a short "when the estimators disagree" diagnostic section.
Touch-ups
---------
* `names_and_concepts.md` -- replaces the single legacy `EstimationOptions` paragraph with one section per estimator (`CHSEstimationOptions`, `AFEstimationOptions`, `AMNEstimationOptions`). Notes that `n_mixtures` lives on `ModelSpec` because it changes the model, not the optimizer.
* `reference_guides/transition_functions.md` -- one-line note that the same transition functions work for all three estimators, plus the current caveat that AMN's Stage 3 doesn't yet honour `@register_params` custom transitions.
* `explanations/architecture.md` -- adds `common/selector.py` (`select_by_loc`, `align_index_names`) and the new public helpers `collect_fixed_locs` / `project_to_probability_constraints` in `common/constraints.py` to the file layout, matching the post-review refactor.
* `myst.yml` -- TOC additions for the two new how-tos.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The Marvin AF-jaxopt run (job 25980258) failed on every single sim
with:
JaxRuntimeError: INVALID_ARGUMENT: ... Reduction function's
accumulator shape at index 0 differs from the init_value shape:
s32[] vs s64[], for instruction %scatter ...
metadata={op_name="jit(update)/jit(argsort)/sort"}
Failed after permutation_sort_simplifier
jaxopt's `LBFGSB.update` calls `jnp.argsort` internally. With x64
off at jaxopt import time, the resulting sort emits int32 indices,
which then scatter into an int64 operand the rest of the optimizer
builds. XLA's `permutation_sort_simplifier` pass rejects that
mismatch on JAX >= 0.10 (the cuda13 wheel on Marvin uses 0.10).
Pre-0.10 jaxes accepted the mixed shapes; 0.10 tightened the verifier.
Every CHS / AF / AMN entry point already calls
`jax.config.update("jax_enable_x64", True)` inside its function
body, so the package has effectively always assumed x64. Moving
the flip to `skillmodels/__init__.py` (and setting
`JAX_ENABLE_X64=1` in the environment before `import jax`) makes
it apply at import time -- which is what `jaxopt`'s
module-level jit kernels need. `af/jaxopt_backend.py` gets the
same belt-and-suspenders guard for callers that import it
directly without first going through `skillmodels/__init__.py`.
Net effect: no behaviour change for CHS / AF-optimagic / AMN
callers (x64 was already on by the time they ran). The jaxopt
path now runs cleanly on JAX 0.10 / cuda13. Local AF jaxopt
test suite (30 tests) still passes.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Stacked on the af-estimator branch (`2206212` series). Mirrors the
pattern in pylcm PR #355: a per-exception `BeartypeConf` plus a
`beartype_init` class decorator routes parameter-type violations at
every documented entry point through a skillmodels-specific
exception class, so callers can write narrowly-scoped `except`
clauses against a stable hierarchy rather than catching beartype's
framework exception.
Layout
------
* `src/skillmodels/exceptions.py` -- six `TypeError` subclasses,
organised by perimeter (`ModelSpecInitializationError`,
`OptionsInitializationError`, `EstimationCallError`,
`InferenceCallError`, `SimulationCallError`,
`DiagnosticsCallError`), all inheriting from a common
`SkillmodelsInputError` for callers that want to catch the whole
hierarchy in one go.
* `src/skillmodels/_beartype_conf.py` --
`_conf(exc)` builds a `BeartypeConf` with
`violation_param_type=exc`, `strategy=BeartypeStrategy.On` (full
O(n) container scan; entry points are called rarely compared to
the JIT-compiled hot path each one kicks off), and
`is_pep484_tower=True`. `beartype_init(conf)` is a class decorator
that wraps only `__init__` so non-public method-level annotation
drift on instance methods does not surface at construction time (see the sketch below).
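A minimal sketch of the two helpers named above, assuming `beartype_init` simply rebinds a beartype-wrapped `__init__`; the real `_beartype_conf.py` and exception module may differ in detail.

```python
# Perimeter conf + init-only class decorator, as described in the Layout section.
from beartype import BeartypeConf, BeartypeStrategy, beartype

class SkillmodelsInputError(TypeError):
    """Common base of the perimeter exception hierarchy."""

class ModelSpecInitializationError(SkillmodelsInputError):
    """Raised when a model-spec object is constructed with wrong types."""

def _conf(exc: type[TypeError]) -> BeartypeConf:
    return BeartypeConf(
        violation_param_type=exc,
        strategy=BeartypeStrategy.On,   # full container scan at the perimeter
        is_pep484_tower=True,           # ints satisfy float-typed parameters
    )

def beartype_init(conf: BeartypeConf):
    """Class decorator that type-checks only __init__."""
    def decorate(cls):
        cls.__init__ = beartype(conf=conf)(cls.__init__)
        return cls
    return decorate

MODEL_SPEC_CONF = _conf(ModelSpecInitializationError)

@beartype_init(MODEL_SPEC_CONF)
class FactorSpec:  # simplified stand-in for the real dataclass
    def __init__(self, measurements: tuple[tuple[str, ...], ...]) -> None:
        self.measurements = measurements
```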
Decoration sites
----------------
* `@beartype_init(MODEL_SPEC_CONF)` on `FactorSpec`, `AnchoringSpec`,
`ModelSpec`, `Normalizations`.
* `@beartype_init(OPTIONS_CONF)` on `CHSEstimationOptions`,
`AFEstimationOptions`, `AMNEstimationOptions`.
* `@beartype(conf=ESTIMATION_CONF)` on `get_maximization_inputs`,
`get_filtered_states`, `estimate_af`, `estimate_amn`,
`get_af_posterior_states`, `get_amn_posterior_states`.
* `@beartype(conf=INFERENCE_CONF)` on
`compute_af_standard_errors`, `compute_amn_standard_errors`.
* `@beartype(conf=SIMULATION_CONF)` on `simulate_dataset`,
`simulate_policy_effect`.
* `@beartype(conf=DIAGNOSTICS_CONF)` on
`decompose_measurement_variance`,
`summarize_measurement_reliability`,
`plot_residual_boxplots`, `plot_likelihood_contributions`,
`create_state_ranges`, `plot_correlation_heatmap`,
`get_measurements_corr`, `get_quasi_scores_corr`,
`get_scores_corr`, `univariate_densities`,
`bivariate_density_contours`, `bivariate_density_surfaces`,
`combine_distribution_plots`, `get_transition_plots`,
`combine_transition_plots`.
Side effects of perimeter-only validation
-----------------------------------------
* `_check_measurements`'s type-shape arm in
`common/check_model.py` is now dead code: the
`tuple[tuple[str, ...], ...]` annotation on
`FactorSpec.measurements` makes beartype reject every malformed
measurement structure at construction time. The function is kept
(the report aggregator might still surface non-type issues a
beartype container scan can't see), but the corresponding two
tests in `tests/test_check_model.py` are rewritten to assert
`ModelSpecInitializationError` at `FactorSpec(...)` time
instead of asserting a soft message in the aggregator output.
* `tests/test_af_jaxopt_backend.py::test_optimizer_backend_rejects_unknown_value`
now asserts `OptionsInitializationError` from beartype's
`Literal` check (which fires before
`AFEstimationOptions.__post_init__`'s manual ValueError).
* `tests/test_amn_plot_harmonization.py::test_get_filtered_states_rejects_both_af_and_amn_results`
now asserts `EstimationCallError`; the prior body-level
`"only one of"` ValueError is still in place but is unreachable
from this fixture, which passes the same AMN result to both
parameters and so trips the type guard on `af_result` first.
* `chs/filtered_states.py` imports `AFEstimationResult` /
`AMNEstimationResult` at runtime rather than under
`TYPE_CHECKING` so beartype can resolve the annotation; ruff's
TC003 autofix had been silently unforwarding the string forward
refs.
Verification
------------
* `pixi run -e tests-cpu pytest tests/ -q -k "not long_running"` --
529 passed, 1 deselected (same count as before this commit; no
regressions).
* `pixi run ty` -- clean.
* `prek run --all-files` -- clean.
Out of scope (follow-up PRs)
----------------------------
* Whole-package activation via `beartype.claw.beartype_package("skillmodels")`
in `tests/conftest.py`. That probe would surface internal-helper
annotation drift the same way pylcm's part-3 PR will, and is left
for a separate review.
* AGENTS-level conventions documentation. The perimeter is in place;
the rule for where to put the next decorator is "wherever the
signature is documented to the user" -- to be expanded once the
pattern has settled.
🤖 Generated with [Claude Code](https://claude.com/claude-code)
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
The previous fix (2206212) tried to address this by enabling x64 before `import jaxopt`. The reproducer on sonny (jax 0.10.0, jaxopt 0.8.5) shows that's insufficient: even with x64 on before any jaxopt code runs and float64 inputs throughout, `LBFGSB.update`'s jit-compiled `jnp.argsort` still emits an s32 reduction accumulator while the surrounding scatter operand is built as s64. XLA's `permutation_sort_simplifier` HLO pass rejects the mismatch with `INVALID_ARGUMENT: Reduction function's accumulator shape at index 0 differs from the init_value shape: s32[] vs s64[]`.

Disabling just `permutation_sort_simplifier` via `XLA_FLAGS` fixes the crash, keeps every other XLA optimisation intact, and is a no-op on JAX < 0.10 (the pass doesn't exist there). The flag must be set before `import jax` because XLA reads `XLA_FLAGS` once at backend init. Applied in two places:

- `skillmodels/__init__.py`: the primary entry point. Appends to any pre-existing `XLA_FLAGS` so user flags aren't clobbered.
- `skillmodels/af/jaxopt_backend.py`: belt-and-suspenders for direct module imports that skip the package init.

The previous comment block tying the bug to "x64 off at import time" was wrong about the root cause; it is replaced with the actual XLA-pass explanation. The `JAX_ENABLE_X64=1` setting is retained because the AF pipeline assumes float64 throughout.

Verified end-to-end on sonny (jax 0.10): a minimal jaxopt repro that previously crashed now succeeds. Local jaxopt backend tests (7) and the full local suite (485 tests, jax 0.9) still pass.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
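A minimal sketch of the flag handling described above; the surrounding contents of `skillmodels/__init__.py` are assumed, not reproduced:

```python
# Must run before the first `import jax`: XLA reads XLA_FLAGS once at backend init.
import os

_PASS_FLAG = "--xla_disable_hlo_passes=permutation_sort_simplifier"
_existing = os.environ.get("XLA_FLAGS", "")
if _PASS_FLAG not in _existing:
    # Append rather than overwrite so user-supplied XLA flags are not clobbered.
    os.environ["XLA_FLAGS"] = f"{_existing} {_PASS_FLAG}".strip()

# The AF pipeline assumes float64 throughout.
os.environ.setdefault("JAX_ENABLE_X64", "1")
```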
The perimeter decorators (PR #90 / commit d7d29ea) covered only public entry points. The claw activation in `tests/conftest.py` extends type enforcement to every annotated callable in the package during the test run, catching annotation drift on internal helpers that would otherwise silently flow through.

Configuration:

- `is_pep484_tower=True` to mirror the perimeter conf so `int` satisfies `float`-typed parameters.
- `claw_skip_package_names=("skillmodels.chs.qr",)` because JAX's `@custom_jvp` decorator stores the secondary `.defjvp` setter on the wrapped object; beartype.claw's wrapping strips it.

Annotation drift fixes (sources of truth, not type-system theater):

* `FixedConstraintWithValue` moved to its own module `common/fixed_constraint.py`. `transition_functions.py` previously imported it under `if TYPE_CHECKING:` to avoid a circular import with `constraints.py`; beartype.claw can't resolve those forward refs at decoration time. The leaf type now lives where both modules can pull it without a cycle.
* Same TYPE_CHECKING removal for `ModelSpec` in `af/types.py` and `amn/types.py`.
* JAX-traced helpers (`_at_node`, `_chain_one_component`, `_compute_investment`, kalman / likelihood entry points, transition pipeline plumbing): annotations relaxed to accept `Array | np.ndarray` / `float | Array` / `int | np.integer` where the runtime contract is genuinely mixed. JAX vmap traces ints as `BatchTracer`; numpy and jax arrays interconvert freely through these signatures.
* `MixtureComponent`, `ConditionalDistribution`, `ChainLink` dataclass fields: accept `np.ndarray` alongside `jax.Array` since estimators fill them with both.
* `TransitionInfo.param_names`: now built with explicit `tuple()` conversion at the boundary in `process_model.py` so the `MappingProxyType[str, tuple[str, ...]]` annotation actually holds.
* `get_has_endogenous_factors` now casts the pandas `.any()` result to `bool` instead of relying on a `# ty: ignore`.
* `NDArray[np.floating[Any]]` (which beartype doesn't accept as a dtype hint) replaced with `NDArray[np.float64]` in `chs/process_debug_data` and `common/visualize_factor_distributions`.
* `NDArray[np.floating]` in `common/simulate_data` widened to `NDArray[np.floating] | Array`.
* Internal duck-typed validators (`_check_anchoring`, `_process_factors`) re-typed as `Any` with `# noqa: ANN401` and an inline comment; they exist precisely to take partially-built objects.
* `_aug_periods_from_period`: `dict[int, int]` → `Mapping[int, int]` (production passes a `MappingProxyType`).

Tests:

- `tests/test_check_model.py`: re-add `# ty: ignore` on the two `FactorSpec(measurements=...)` calls that intentionally pass `list` where `tuple` is required; they verify the beartype perimeter catches the shape error.
- `tests/test_transition_functions.py::test_constant`: rewritten to pass real JAX arrays now that the claw type-checks `constant`.
- Stale `# ty: ignore[invalid-argument-type]` directives stripped from `test_check_model.py`, `test_correlation_heatmap.py`, `test_process_debug_data.py` (and one in `simulate_data.py`) — the annotations they were silencing have been relaxed.

Verification: 495 tests pass with claw enabled (`pixi run -e tests-cpu tests`); `pixi run ty` clean; `prek run --all-files` clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
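A hedged sketch of the activation in `tests/conftest.py`; the real conftest may carry additional configuration:

```python
# Must be installed before any `import skillmodels` during test collection,
# so the import hook can instrument the package's modules as they load.
from beartype import BeartypeConf
from beartype.claw import beartype_package

beartype_package(
    "skillmodels",
    conf=BeartypeConf(
        is_pep484_tower=True,  # mirror the perimeter conf: int satisfies float
        # @custom_jvp's secondary .defjvp setter does not survive claw wrapping
        claw_skip_package_names=("skillmodels.chs.qr",),
    ),
)
```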
CI run 25868871644 surfaced 14 failures in `tests/test_af_inference.py` that local runs missed (the file is unmarked but the local sweep excluded similarly-shaped tests via `-k "not long_running and not end_to_end"`). All failures had the same root cause as the bulk of the prior claw-activation diff: AF transition / chain helpers were annotated with strict `jax.Array` parameters, but the runtime path through `af.inference` constructs `prev_distribution`, chain link inputs, and chol/mean blocks from `np.ndarray`. Beartype rejected them under the now-active claw.

Relaxed sites:

- `af.likelihood.af_per_obs_loglike_transition` / `af_loglike_transition` / `_integrate_transition_chain`: `prev_distribution` widened from `dict[str, Array]` to `Mapping[str, Array | np.ndarray]`. `Mapping` (covariant) lets callers pass `dict[str, Array]` without an explicit cast.
- `af.likelihood._map_over_obs`: `*xs: Array` → `*xs: Array | np.ndarray`.
- `af.likelihood._integrate_transition_single_obs`: `obs_cond_weights`, `obs_cond_means`, `cond_chols` widened to `Array | np.ndarray`.
- `af.likelihood._rebuild_chain_at_period`: `initial_mean`, `initial_chol` widened to `Array | np.ndarray`. Internal `theta` bound through `jnp.asarray(...)` so downstream `_compute_investment` still sees an `Array`.
- `af.likelihood._compute_investment`: `inv_eq_params`, `inv_sds` widened (covered earlier; ty-narrowing followed naturally).

Verification:

- `pixi run -e tests-cpu pytest tests/test_af_inference.py` — 14 / 14 pass (was 1 failed + 13 errors).
- `pixi run ty` clean.
- `prek run --all-files` clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
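The widening pattern, sketched on one of the signatures named above; the body and the remaining parameters are elided, so only the annotation change is meaningful here:

```python
from collections.abc import Mapping

import numpy as np
from jax import Array

# Before (rejected the np.ndarray-valued dicts built in af.inference):
#   def af_loglike_transition(prev_distribution: dict[str, Array], ...) -> Array: ...

# After: Mapping accepts any mapping type, and the value union covers both array kinds.
def af_loglike_transition(
    prev_distribution: Mapping[str, Array | np.ndarray],
) -> Array:
    ...  # real signature has further parameters; elided in this sketch
```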
Previously the jaxopt backend stopped on `max|projected_grad| < tol` only. That's not how scipy_lbfgsb stops in practice: scipy stops on EITHER `gtol_abs` OR a relative function-value decrease `ftol_rel`. On the skill-formation likelihoods used in the Monte Carlo benchmarks, the loglikelihood goes locally flat before the gradient does, so scipy's ftol channel fires ~100% of the time and gradient-norm < 1e-5 is essentially never the actual stopping criterion in production. Without the ftol channel, jaxopt would grind down the gradient while scipy declared success at the same point — an apples-to-oranges asymmetry that made the AF-jaxopt vs AF-optimagic timing comparison meaningless.

Implementation: drop jaxopt's built-in `run()` and drive the solver through an explicit `init_state` + `update` loop with the same gtol-OR-ftol stop. Default values now match scipy_lbfgsb's defaults (`gtol_abs=1e-5`, `ftol_rel=2.22e-9`, `maxiter=15000`). The wrapper accepts both canonical scipy keys (`convergence_gtol_abs`, `convergence_ftol_rel`, `stopping_maxiter`) and the historical jaxopt keys (`tol`, `maxiter`) so the same `optimizer_options` dict works for either backend.

This makes the two LBFGSB implementations stop on byte-identical rules; the only remaining differences are internal (line search, step acceptance, curvature-pair filtering) — which is the comparison that's actually interesting.

Verified: 7 jaxopt_backend tests still pass; ty clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
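A minimal sketch of the gtol-OR-ftol loop, under simplifying assumptions: parameters are a flat array, the objective and gradient are recomputed outside the solver purely to keep the stop check explicit, and the raw gradient stands in for the projected gradient the real wrapper uses. The actual `af/jaxopt_backend.py` code is not reproduced here.

```python
import jax
import jax.numpy as jnp
import jaxopt


def minimize_lbfgsb(fun, x0, bounds, gtol_abs=1e-5, ftol_rel=2.22e-9, maxiter=15_000):
    """Drive LBFGSB step by step and stop on scipy_lbfgsb's gtol OR ftol rule."""
    solver = jaxopt.LBFGSB(fun=fun, maxiter=maxiter)
    value_and_grad = jax.value_and_grad(fun)
    params, state = x0, solver.init_state(x0, bounds=bounds)
    f_prev, _ = value_and_grad(x0)
    for _ in range(maxiter):
        params, state = solver.update(params, state, bounds=bounds)
        f_new, grad = value_and_grad(params)
        if jnp.max(jnp.abs(grad)) < gtol_abs:
            break  # gtol channel (sup-norm of the gradient)
        if (f_prev - f_new) / max(abs(float(f_prev)), abs(float(f_new)), 1.0) < ftol_rel:
            break  # ftol channel: relative function-value decrease
        f_prev = f_new
    return params
```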
`AFEstimationOptions.__post_init__` runs `ensure_containers_are_immutable`
on the user's `optimizer_options` dict, which recursively wraps every
nested dict in `MappingProxyType`. The two `om.minimize(..., **dict(
af_options.optimizer_options))` call sites in `af/initial_period.py`
and `af/transition_period.py` only unwrap the outer layer; an
`algo_options={"convergence_gtol_abs": 1e-5, ...}` user dict therefore
arrives at `om.minimize` as `algo_options=MappingProxyType(...)` and
trips optimagic's `isinstance(algo_options, dict)` check with
`ValueError: algo_options must be a dictionary or None`.
Surfaced on the Marvin 3-way Monte Carlo run where AF optimagic
failed 100% of sims with that exact ValueError; AF jaxopt and CHS
were unaffected (jaxopt's wrapper consumes simple top-level keys;
CHS's `om.minimize` call passes a plain dict directly).
Fix: add `to_plain_dict` in `common/types.py` (inverse of
`_make_immutable` — recursively unwraps MappingProxyType/tuple/
frozenset back to dict/list/set) and use it at both AF optimagic
call sites. The jaxopt path is unchanged because its `options` are
flat scalars, not nested.
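A minimal sketch of the helper, assuming the recursion mirrors `_make_immutable` exactly; the shipped version may handle more container types:

```python
from types import MappingProxyType
from typing import Any


def to_plain_dict(obj: Any) -> Any:
    """Recursively undo _make_immutable: MappingProxyType -> dict, tuple -> list, frozenset -> set."""
    if isinstance(obj, (dict, MappingProxyType)):
        return {key: to_plain_dict(value) for key, value in obj.items()}
    if isinstance(obj, tuple):
        return [to_plain_dict(value) for value in obj]
    if isinstance(obj, frozenset):
        return {to_plain_dict(value) for value in obj}
    return obj
```

At the two optimagic call sites, `**dict(af_options.optimizer_options)` becomes `**to_plain_dict(af_options.optimizer_options)`, so a nested `algo_options` arrives at `om.minimize` as a plain `dict`.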
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Summary
-------
* `af/` subpackage implementing the Antweiler & Freyberger (2025) estimator
  as an alternative to the CHS Kalman filter.
* Same `ModelSpec` interface — users switch estimator by calling
  `estimate_af()` instead of `get_maximization_inputs()` + `om.maximize()`.
* `log_ces` / `linear` / `translog` transitions, endogenous factors via
  explicit investment equation.
* `get_filtered_states()` interface: pass `af_result=` for AF posterior
  states, omit for CHS filtered states.
* Beartype perimeter on the user-facing API, `beartype.claw` activation in
  tests, and JAX 0.10 / cuda13 workarounds. See sections below.

AF estimator (`src/skillmodels/af/`)
------------------------------------
* `estimate_af(model_spec, data, af_options, start_params)` →
  `AFEstimationResult`.
* `ProbabilityConstraint` for `log_ces` gammas, satisfied at start values.
* Investment equation `I = β₀ + β₁θ + β₂Y + σ_I ε` for endogenous factors.
* `start_params` support: user-supplied starting values override heuristic
  defaults.
* `get_filtered_states(model_spec, data, params, af_result=result)` computes
  quadrature-based posterior means per individual / period.
* Standard errors via `compute_af_standard_errors`.

Optimizer backends
------------------
`AFEstimationOptions.optimizer_backend` chooses how each period's MLE is
solved:
* `"optimagic"` (default fallback): `om.minimize(algorithm="scipy_lbfgsb",
  ...)`. Required when user equality / probability constraints exist
  (jaxopt can't fold those).
* `"jaxopt"`: `jaxopt.LBFGSB`, run directly on device. Avoids the
  host↔device transfer optimagic incurs once per likelihood call.
* `"auto"`: pick `"jaxopt"` iff a JAX GPU is visible and the model is
  jaxopt-compatible (no `log_ces*` transitions, no user-supplied
  constraints); otherwise `"optimagic"`.

The jaxopt wrapper now matches scipy_lbfgsb's gtol OR ftol stopping rule:
it drives the solver through an explicit `init_state` + `update` loop and
checks `‖projected_grad‖∞ < gtol_abs` OR
`(f_k − f_{k+1}) / max(|f_k|, |f_{k+1}|, 1) < ftol_rel` after each step.
The same `optimizer_options` keys (`convergence_gtol_abs`,
`convergence_ftol_rel`, `stopping_maxiter`) work for either backend, so
per-sim Monte Carlo timing benchmarks compare identical stopping rules and
differ only in internal LBFGSB mechanics.

Beartype perimeter on user-facing API
-------------------------------------
Mirrors the pattern in pylcm PR #355: a per-exception `BeartypeConf` plus a
`beartype_init` class decorator routes parameter-type violations at every
documented entry point through a skillmodels-specific exception class, so
callers can write narrowly-scoped `except` clauses against a stable
hierarchy rather than catching beartype's framework exception.

Exceptions (`src/skillmodels/exceptions.py`): six `TypeError` subclasses of
a common `SkillmodelsInputError`, organised by perimeter:
* `ModelSpecInitializationError` — `FactorSpec`, `AnchoringSpec`,
  `ModelSpec`, `Normalizations`
* `OptionsInitializationError` — `CHSEstimationOptions`,
  `AFEstimationOptions`, `AMNEstimationOptions`
* `EstimationCallError` — `get_maximization_inputs`, `get_filtered_states`,
  `estimate_af`, `estimate_amn`, `get_af_posterior_states`,
  `get_amn_posterior_states`
* `InferenceCallError` — `compute_af_standard_errors`,
  `compute_amn_standard_errors`
* `SimulationCallError` — `simulate_dataset`, `simulate_policy_effect`
* `DiagnosticsCallError` — `decompose_measurement_variance`,
  `summarize_measurement_reliability`, `plot_residual_boxplots`,
  `plot_likelihood_contributions`, `create_state_ranges`,
  `plot_correlation_heatmap`, `get_measurements_corr`,
  `get_quasi_scores_corr`, `get_scores_corr`, `univariate_densities`,
  `bivariate_density_contours`, `bivariate_density_surfaces`,
  `combine_distribution_plots`, `get_transition_plots`,
  `combine_transition_plots`

Decorator + config (`src/skillmodels/_beartype_conf.py`):
* `_conf(exc)` — `BeartypeConf` with `violation_param_type=exc`,
  `strategy=BeartypeStrategy.On` (full O(n) container scan),
  `is_pep484_tower=True`.
* `beartype_init(conf)` — class decorator that wraps only `__init__`.
  Avoids surfacing non-public annotation drift on instance methods that has
  nothing to do with parameter validation at construction time.

Whole-package `beartype.claw` activation in tests (`tests/conftest.py`)
-----------------------------------------------------------------------
`beartype.claw.beartype_package("skillmodels", conf=...)` turns annotation
drift on internal helpers into `BeartypeCallHintParamViolation` during the
test run. `skillmodels.chs.qr` is excluded because JAX's `@custom_jvp`
decorator's secondary `.defjvp` attribute doesn't survive beartype's wrap.
Activating the claw surfaced ~80 internal annotation drifts (Array /
np.ndarray / int / Mapping / dict-vs-MappingProxyType / TYPE_CHECKING-only
forward refs), all fixed in the commits that follow.

Side effects
------------
* `_check_measurements`'s type-shape arm in `common/check_model.py` is now
  dead code: the `tuple[tuple[str, ...], ...]` annotation on
  `FactorSpec.measurements` makes beartype reject every malformed
  measurement structure at construction time. The function is kept for
  non-type issues the container scan can't see; the two tests that
  previously asserted soft errors are rewritten to assert
  `ModelSpecInitializationError` at `FactorSpec(...)` time.
* `tests/test_af_jaxopt_backend.py::test_optimizer_backend_rejects_unknown_value`
  now asserts `OptionsInitializationError` from beartype's `Literal` check
  (which fires before `AFEstimationOptions.__post_init__`'s manual
  `ValueError`).
* `tests/test_amn_plot_harmonization.py::test_get_filtered_states_rejects_both_af_and_amn_results`
  now asserts `EstimationCallError`.
* `chs/filtered_states.py` imports `AFEstimationResult` /
  `AMNEstimationResult` at runtime rather than under `TYPE_CHECKING` so
  beartype can resolve the annotation; ruff's TC003 autofix had been
  silently unforwarding the string forward refs.
* `FixedConstraintWithValue` moved to its own module
  `common/fixed_constraint.py` to break a circular import (`constraints.py`
  imports `transition_functions.py`; the leaf type now lives where both
  modules can pull it without a cycle).

JAX 0.10 / cuda13 workarounds
-----------------------------
* `XLA_FLAGS=--xla_disable_hlo_passes=permutation_sort_simplifier` is set at
  package import to bypass a JAX 0.10 XLA pass that mis-lowers the `argsort`
  inside `jaxopt.LBFGSB.update` (emits an s32 reduction accumulator into an
  s64 scatter operand). No-op on JAX < 0.10.
* `JAX_ENABLE_X64=1` is set at package import time so a transitive
  `import jaxopt` sees x64 as the default integer width.

Test plan
---------
* `pixi run -e tests-cpu pytest tests/ -q -k "not long_running"` — all green
  with beartype.claw enabled
* `pixi run ty` — clean
* `prek run --all-files` — clean
* `pytest -m long_running` — MODEL2 AF vs CHS comparison (both estimators
  optimised from the same naive start values)

🤖 Generated with Claude Code