
numpyro/NUTS fit_map path is significantly slower than emcee #34

@richteague

Description


Summary

After the fit_map emcee speedups (f23cf29, 867bc38, 215a45e; cumulatively ~16×), the mcmc='numpyro' path is now slower in wall-clock time than mcmc='emcee' on typical fits, even though it is dramatically more sample-efficient (256× fewer samples to reach the same posterior resolution on the validation fit in REFACTORING_PLAN.md §5.1b).

The emcee speedups combine a JIT'd closure with a vmap'd batch ln-prob, compiling the per-step log-probability into a single XLA dispatch across all walkers. numpyro can't share that optimisation: NUTS extends each trajectory until a U-turn, which is a per-chain condition, so vmap'ing across chains doesn't help.
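For reference, the emcee-side pattern is roughly the following. A minimal sketch with an illustrative log_probability; this is not the actual fit_map closure, which captures the data and model kwargs the same way:

```python
import jax
import jax.numpy as jnp

# Illustrative stand-in for the per-walker log-probability.
def log_probability(theta):
    return -0.5 * jnp.sum(theta ** 2)

# vmap across the walker axis, then JIT: one XLA dispatch evaluates
# ln-prob for the whole ensemble on each emcee step.
batch_log_prob = jax.jit(jax.vmap(log_probability))

# emcee consumes the batched function via its vectorize flag:
#   sampler = emcee.EnsembleSampler(nwalkers, ndim, batch_log_prob, vectorize=True)
```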

Profile of a current fit

9-parameter HD163296 3D fit (docs/tutorials/tutorial_6_numpyro.ipynb), 500 warmup + 500 samples, single chain:

  • Mean leapfrog steps per NUTS iteration: 127
  • Median: 95
  • ~39 % of iterations hit the max_tree_depth=8 cap (256 leapfrog steps)
  • Each gradient evaluation costs ~2× a likelihood evaluation

So one numpyro sample does the gradient work of ~250 emcee likelihood evaluations (127 leapfrog steps × ~2× cost per gradient).
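The leapfrog statistics above can be reproduced from numpyro's extra fields. A sketch, with a placeholder model standing in for the real fit:

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():  # placeholder standing in for the fit_map numpyro model
    numpyro.sample("theta", dist.Normal(jnp.zeros(9), 1.0))

mcmc = MCMC(NUTS(model, max_tree_depth=8), num_warmup=500, num_samples=500)
# Ask NUTS to record per-iteration leapfrog counts and divergences.
mcmc.run(jax.random.PRNGKey(0), extra_fields=("num_steps", "diverging"))

steps = mcmc.get_extra_fields()["num_steps"]
print("mean leapfrogs:", steps.mean())
print("median leapfrogs:", jnp.median(steps))
# Fraction of iterations that ran the longest (capped) trajectory,
# assuming the cap was hit at least once.
print("frac at cap:", (steps == steps.max()).mean())
```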

Micro-benchmark (50 warmup + 50 samples, measured after JIT warm-up)

config                                                  wall time
1 chain, max_tree_depth=8 (previous tutorial default)    40.1 s
1 chain, max_tree_depth=6 (new tutorial default)         23.1 s
4 chains, chain_method='sequential'                     197.3 s
4 chains, chain_method='vectorized'                     194.3 s
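For reference, these configurations map onto the numpyro API like this (placeholder model; values from the table):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model():  # placeholder standing in for the fit_map numpyro model
    numpyro.sample("theta", dist.Normal(jnp.zeros(9), 1.0))

# Rows 1-2: single chain; max_tree_depth caps the trajectory doubling.
mcmc = MCMC(NUTS(model, max_tree_depth=6), num_warmup=50, num_samples=50)

# Rows 3-4: four chains, run back-to-back or vmap'd together.
mcmc4 = MCMC(
    NUTS(model, max_tree_depth=6),
    num_warmup=50,
    num_samples=50,
    num_chains=4,
    chain_method="vectorized",  # or "sequential"
)
mcmc4.run(jax.random.PRNGKey(0))
```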

Findings

  • Capping max_tree_depth at 6 gives a ~42 % wall-time speedup with minimal posterior loss — tutorial 6 now recommends 6.
  • chain_method='vectorized' does not help. NUTS' adaptive tree length is per-chain; vmap'd chains must all run to the longest tree on each iteration, cancelling the vmap win (vectorized is only a win for HMC with fixed L; see the toy sketch after this list).
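A toy reproduction of that effect (illustrative JAX only, not fit_map code): vmap over a data-dependent while_loop keeps every batch element iterating, with updates masked, until the longest-running element finishes.

```python
import jax
import jax.numpy as jnp
from jax import lax

def steps_until(threshold):
    # Count loop iterations until x reaches threshold: a stand-in for
    # "leapfrog until U-turn", whose length differs per chain.
    def cond(state):
        _, x = state
        return x < threshold
    def body(state):
        i, x = state
        return i + 1, x + 1.0
    return lax.while_loop(cond, body, (0, 0.0))[0]

# Under vmap the loop runs until every element's cond is False, so the
# cheap element pays for the expensive one: the vectorized-NUTS problem.
print(jax.vmap(steps_until)(jnp.array([3.0, 100.0])))  # [3 100], but ~100 iterations of work each
```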

Untried directions

  • target_accept_prob below the 0.8 default — accept noisier steps, so shorter trajectories. Trade-off: more divergences.
  • dense_mass=True — adapt a full covariance preconditioner. Could help the strongly correlated z0/psi/r_taper/q_taper block, but the warmup adaptation is slower. (This knob, target_accept_prob, and the JIT-cache audit are sketched after this list.)
  • More aggressive image downsampling — _make_model cost scales with pixel count; high-SNR pixels dominate the posterior.
  • Audit the numpyro JIT cache: confirm the model is compiled exactly once per fit and not invalidated between iterations.
  • See whether a bounded-horizon scan over leapfrog steps is feasible (probably blocked by the U-turn termination condition, but worth checking).
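A sketch of how the first, second, and fourth items could be wired up (placeholder model; the 0.7 value is illustrative, not a tested recommendation):

```python
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import NUTS

# JIT-cache audit: log every XLA compilation, so a model that recompiles
# between iterations shows up immediately in the output.
jax.config.update("jax_log_compiles", True)

def model():  # placeholder standing in for the fit_map numpyro model
    numpyro.sample("theta", dist.Normal(jnp.zeros(9), 1.0))

kernel = NUTS(
    model,
    max_tree_depth=6,
    target_accept_prob=0.7,  # below the 0.8 default: shorter trajectories, but expect more divergences
    dense_mass=True,         # full covariance mass matrix for the correlated parameter block
)
```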

Workaround

Stay on mcmc='emcee' (the default) for routine fits. Use mcmc='numpyro' when sample efficiency is the actual bottleneck: a very expensive single likelihood, very long autocorrelation times under emcee, or a GPU being available.
