Add parallel associative-scan algorithms for the quasiseparable solver #269
Merged
The existing quasiseparable operations use `jax.lax.scan` with O(N) sequential depth, which serializes badly on accelerators. This adds `jax.lax.associative_scan`-based implementations with O(log N) depth, selectable via a `parallel=True` flag threaded from `GaussianProcess` through `QuasisepSolver` down to the QSM matmul/solve/cholesky/inv methods.

- `ops.py`: parallel and sequential implementations of lower/upper matmul, lower/upper solve, cholesky, and symm_inv. The cholesky and symm_inv forward passes share a common Riccati associative scan.
- `core.py`: QSM methods take `parallel: bool = False` and dispatch to `ops`.
- `solver.py`: `QuasisepSolver(..., parallel=True)` uses the parallel path for factorization, triangular solves, and matrix products.
- `block.py`: `Block` gains `.mT`, batched `to_dense()`, and batched matmul.

This also fixes a pre-existing crash in `LowerTriQSM.inv()` on summed kernels. `SquareQSM.inv` remains sequential.
Note
This builds on and supersedes #210, which included only the parallel matmul implementation.
This PR adds `jax.lax.associative_scan`-based implementations of the core quasiseparable operations alongside the existing `jax.lax.scan` versions, and threads a `parallel: bool` flag from the GP layer down to select between them.

Usage
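For example, a minimal sketch at the GP level (the exact keyword plumbing from `GaussianProcess` down to the solver is assumed here from the description above, so treat it as illustrative):

```python
import jax.numpy as jnp

from tinygp import GaussianProcess, kernels
from tinygp.solvers import QuasisepSolver

X = jnp.linspace(0.0, 10.0, 1000)
y = jnp.sin(X)

kernel = kernels.quasisep.Matern32(scale=1.5)

# parallel=True selects the O(log N)-depth associative-scan path;
# the default (parallel=False) keeps the existing sequential lax.scan.
gp = GaussianProcess(kernel, X, diag=1e-3, solver=QuasisepSolver, parallel=True)
print(gp.log_probability(y))
```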
or directly on the QSM types:
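A sketch of that path (method names follow this PR's description; `to_symm_qsm` is the existing quasisep-kernel constructor, and the `parallel` keyword is the new flag):

```python
import jax.numpy as jnp

from tinygp.kernels import quasisep

X = jnp.linspace(0.0, 10.0, 1000)
b = jnp.sin(X)[:, None]

K = quasisep.Matern32(scale=1.5).to_symm_qsm(X)  # SymmQSM representation of k(X, X)
L = K.cholesky(parallel=True)                    # lower-triangular factor via the parallel Riccati scan
z = L.solve(b, parallel=True)                    # parallel triangular solve
```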
The default remains `parallel=False` (sequential `lax.scan`), so existing behavior is unchanged.

Why
The sequential scans have O(N) depth, which serializes badly on GPUs. The associative-scan formulations have O(log N) depth at the cost of a constant factor more FLOPs, which is a much better fit for accelerators.
Rough wall-clock numbers on a single GPU at J=8:
On CPU the picture is mixed: the parallel matmul is roughly competitive with the sequential version at large N, but the parallel Cholesky is several times slower because each combine step does a small `linalg.solve`. So `parallel=True` is recommended for GPU/TPU and `parallel=False` (the default) for CPU.

Math
The matmul and triangular-solve recurrences are affine in the carry, $f_n = a_n f_{n-1} + b_n$, so they parallelize directly as a prefix scan over the monoid $(A, B) \bullet (A', B') = (A' A,\ A' B + B')$. The Cholesky carry is a Riccati-type (quadratic-over-linear) update that does not fit the affine monoid. Reading each step as a Kalman predict-then-update identifies it with the associative filtering element of Särkkä & García-Fernández: each step is represented by a triple $(A, F, G)$, and composition follows from multiplying the corresponding $2J \times 2J$ Hamiltonian matrices, giving a combine that needs one $J \times J$ linear solve. `SymmQSM.inv` carries the same Riccati state in its forward pass, and its backward pass reduces to a linear conjugation recurrence $z_k = \ell_k^\top z_{k+1} \ell_k + B_k$ that uses the affine-style operator again.
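For the affine case, the combine is just composition of affine maps. A minimal self-contained sketch (not the PR's `ops.py` code) of solving $f_n = a_n f_{n-1} + b_n$ with `jax.lax.associative_scan`, checked against the sequential `lax.scan` formulation:

```python
import jax
import jax.numpy as jnp


def parallel_affine_scan(a, b):
    """Solve f[n] = a[n] @ f[n-1] + b[n] with f[-1] = 0 in O(log N) depth.

    a has shape (N, J, J), b has shape (N, J); returns f with shape (N, J).
    """

    def combine(left, right):
        # Composing x -> A1 x + b1 followed by x -> A2 x + b2 gives
        # x -> (A2 A1) x + (A2 b1 + b2); this operation is associative.
        A1, b1 = left
        A2, b2 = right
        return A2 @ A1, jnp.einsum("...ij,...j->...i", A2, b1) + b2

    _, f = jax.lax.associative_scan(combine, (a, b))
    return f


# Sanity check against the O(N)-depth sequential recurrence.
a = 0.1 * jax.random.normal(jax.random.PRNGKey(0), (128, 3, 3))
b = jax.random.normal(jax.random.PRNGKey(1), (128, 3))
step = lambda carry, ab: (ab[0] @ carry + ab[1],) * 2
f_seq = jax.lax.scan(step, jnp.zeros(3), (a, b))[1]
assert jnp.allclose(parallel_affine_scan(a, b), f_seq, atol=1e-5)
```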
What's included

- `solvers/quasisep/ops.py`: parallel implementations of `lower_matmul`, `upper_matmul`, `lower_solve`, `upper_solve`, `cholesky`, and `symm_inv`, plus their sequential counterparts factored out of `core.py`. The Cholesky and symmetric-inverse forward passes share a common `_riccati_scan` helper.
- `solvers/quasisep/core.py`: all matmul/solve/cholesky/inv methods on the QSM classes take `parallel: bool = False` and dispatch to `ops.py` (see the sketch after this list).
- `solvers/quasisep/solver.py`: `QuasisepSolver` accepts `parallel` and uses it for the factorization, triangular solves, and matrix products.
- `solvers/quasisep/block.py`: `Block` gains `.mT`, batched `to_dense()`, and batched `__matmul__`/`__rmatmul__`. This also fixes a pre-existing bug where `LowerTriQSM.inv()` (used in `condition`) failed on summed kernels.
- `SquareQSM.inv` remains sequential: its asymmetric Riccati recurrence needs a larger associative operator and it isn't on any hot GP path.
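A hedged sketch of the `core.py` to `ops.py` dispatch described above; the variant names on `ops` are placeholders, since the PR lists the operations but not the exact identifiers:

```python
from tinygp.solvers.quasisep import ops


class LowerTriQSM:
    # Illustrative only; the real class carries the quasiseparable generators.
    ...

    def matmul(self, x, parallel: bool = False):
        if parallel:
            # O(log N)-depth jax.lax.associative_scan path (placeholder name)
            return ops.parallel_lower_matmul(self, x)
        # existing O(N)-depth jax.lax.scan path (placeholder name)
        return ops.sequential_lower_matmul(self, x)
```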
Tests

- `test_ops.py` checks each parallel op against its sequential reference on Matern32/Matern52 kernels.
- `test_core.py` parameterizes the matmul/solve/cholesky tests over `parallel={False, True}`.
- `test_solver.py::test_consistent_with_direct` is parameterized over `parallel` and gains a summed-kernel case to exercise the `Block` path end-to-end.
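A hedged sketch of the kind of consistency check described above (the test name, tolerances, and the `to_symm_qsm`/`matmul` usage are illustrative, not the PR's actual test code):

```python
import jax.numpy as jnp
import numpy as np
import pytest

from tinygp.kernels import quasisep


@pytest.mark.parametrize("parallel", [False, True])
def test_matmul_matches_dense(parallel):
    # Compare the QSM matmul (sequential and parallel paths) against a dense reference.
    x = jnp.linspace(0.0, 10.0, 100)
    y = jnp.sin(x)[:, None]
    K = quasisep.Matern32(scale=1.5).to_symm_qsm(x)
    np.testing.assert_allclose(
        K.matmul(y, parallel=parallel), K.to_dense() @ y, rtol=1e-4, atol=1e-6
    )
```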