Skip to content

Commit

Permalink
Fix compatibilty issues (#59)
Browse files Browse the repository at this point in the history
* Fix compatibility issues and requirements files

* Update python version for testing

* Update python version for setup

* Update python version for setup

* Update python version for setup

* Bump version to 1.0.0

* Precision somehow reduced with newer version. Not the first instance I see this, I'll need to check.
  • Loading branch information
AdrienCorenflos committed May 23, 2024
1 parent af7260b commit 36e9f1b
Show file tree
Hide file tree
Showing 20 changed files with 58 additions and 50 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: [ '3.8' ]
python-version: [ '3.12' ]

name: pytest Python ${{ matrix.python-version }}
steps:
Expand Down
22 changes: 13 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
# Parallel square-root statistical linear regression for inference in nonlinear state space models
A generic library for linear and non-linear Gaussian smoothing problems.
The code leverages JAX and implements several linearization algorithms,
both in a sequential and parallel fashion, as well as low-memory cost algorithms computing gradients of required quantities

A generic library for linear and non-linear Gaussian smoothing problems.
The code leverages JAX and implements several linearization algorithms,
both in a sequential and parallel fashion, as well as low-memory cost algorithms computing gradients of required
quantities
(such as the pseudo-loglikelihood of the system).

This code was written by [Adrien Corenflos](https://github.com/AdrienCorenflos) and [Fatemeh Yaghoobi](https://github.com/Fatemeh-Yaghoobi) as a companion code for the article
"Parallel square-root statistical linear regression for inference in nonlinear state space models"
This code was written by [Adrien Corenflos](https://github.com/AdrienCorenflos)
and [Fatemeh Yaghoobi](https://github.com/Fatemeh-Yaghoobi) as a companion code for the article
"Parallel square-root statistical linear regression for inference in nonlinear state space models"
by Fatemeh Yaghoobi, Adrien Corenflos, Sakira Hassan, and Simo Särkkä, ArXiv link: https://arxiv.org/abs/2207.00426

## Installation
## Installation

1. Create a virtual environment and clone this repository
2. Install JAX (preferably with GPU support) following https://github.com/google/jax#installation
3. Run `pip install .`
4. (optional) If you want to run the examples, run `pip install -r examples-requirements.txt`

4. (optional) If you want to run the examples, run `pip install -r requirements-examples.txt`

## Examples

Example uses (reproducing the experiments of our paper) can be found in the [examples folder](../main/notebooks). More low-level examples can be found in the
Example uses (reproducing the experiments of our paper) can be found in the [examples folder](../main/notebooks). More
low-level examples can be found in the
[test folder](../main/tests).

## How to cite

If you find this work useful, please cite us in the following way:

```
Expand Down
7 changes: 2 additions & 5 deletions notebooks/bearing_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def make_parameters(qc, qw, r, dt, s1, s2):

def _get_data(x, dt, a_s, s1, s2, r, normals, observations, true_states):
for i, a in enumerate(a_s):
# with nb.objmode(x='float32[::1]'):
# with nb.objmode(x='float32[::1]'):
F = np.array([[0, 0, 1, 0],
[0, 0, 0, 1],
[0, 0, 0, a],
Expand Down Expand Up @@ -189,13 +189,10 @@ def get_data(x0, dt, r, T, s1, s2, q=10., random_state=None):

x = np.copy(x0).astype(np.float32)
observations = np.empty((T, 2), dtype=np.float32)
true_states = np.zeros((T+1, 5), dtype=np.float32)
true_states = np.zeros((T + 1, 5), dtype=np.float32)
ts = np.linspace(dt, (T + 1) * dt, T).astype(np.float32)
true_states[0, :4] = x
normals = random_state.randn(T, 2).astype(np.float32)

_get_data(x, dt, a_s, s1, s2, r, normals, observations, true_states[1:])
return ts, true_states, observations



9 changes: 3 additions & 6 deletions notebooks/bearing_data_pe.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def make_parameters(qc, qw, r, dt, s1, s2):
[0, qc * dt ** 2 / 2, 0, qc * dt, 0],
[0, 0, 0, 0, dt * qw]])

R = jnp.diag(jnp.array([r**2, 0.1**2]))
R = jnp.diag(jnp.array([r ** 2, 0.1 ** 2]))

observation_function = jit(partial(_observation_function, s1=s1, s2=s2))
transition_function = jit(partial(_transition_function, dt=dt))
Expand All @@ -136,7 +136,7 @@ def make_parameters(qc, qw, r, dt, s1, s2):

def _get_data(x, dt, a_s, s1, s2, r, normals, observations, true_states):
for i, a in enumerate(a_s):
# with nb.objmode(x='float32[::1]'):
# with nb.objmode(x='float32[::1]'):
F = np.array([[0, 0, 1, 0],
[0, 0, 0, 1],
[0, 0, 0, a],
Expand Down Expand Up @@ -189,13 +189,10 @@ def get_data(x0, dt, r, T, s1, s2, q=10., random_state=None):

x = np.copy(x0).astype(np.float32)
observations = np.empty((T, 2), dtype=np.float32)
true_states = np.zeros((T+1, 5), dtype=np.float32)
true_states = np.zeros((T + 1, 5), dtype=np.float32)
ts = np.linspace(dt, (T + 1) * dt, T).astype(np.float32)
true_states[0, :4] = x
normals = random_state.randn(T, 2).astype(np.float32)

_get_data(x, dt, a_s, s1, s2, r, normals, observations, true_states[1:])
return ts, true_states, observations



2 changes: 1 addition & 1 deletion notebooks/population_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def get_data(x0: jnp.ndarray, T: int, Q: jnp.ndarray, lam: jnp.ndarray, key: jnp
array of observations
"""
key, gaussian_key = jax.random.split(key)

chol_Q = jnp.linalg.cholesky(Q)
noises = jax.random.normal(gaussian_key, shape=(T, x0.shape[0])) @ chol_Q.T

Expand Down
2 changes: 1 addition & 1 deletion parsmooth/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.0.0"
__version__ = "1.0.0"

from ._base import MVNSqrt, MVNStandard, FunctionalModel, ConditionalMomentsModel
from .methods import filtering, smoothing, iterated_smoothing, filter_smoother, sampling
5 changes: 3 additions & 2 deletions parsmooth/_pathwise_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jlag
from jax.tree_util import tree_map

from parsmooth._base import MVNStandard, MVNSqrt, FunctionalModel
from parsmooth._utils import tria, none_or_shift, none_or_concat
Expand Down Expand Up @@ -66,7 +67,7 @@ def _sampling_common(key: jnp.ndarray,
filter_trajectory: MVNSqrt or MVNStandard,
linearization_method: Callable,
nominal_trajectory: Union[MVNSqrt, MVNStandard]):
last_state = jax.tree_map(lambda z: z[-1], filter_trajectory)
last_state = tree_map(lambda z: z[-1], filter_trajectory)
filter_trajectory = none_or_shift(filter_trajectory, -1)
F_x, cov_or_chol, b = jax.vmap(linearization_method, in_axes=[None, 0])(transition_model,
none_or_shift(nominal_trajectory, -1))
Expand Down Expand Up @@ -99,7 +100,7 @@ def _standard_gain_and_inc(F, Q, b, xf, eps):
mf, Pf = xf

S = F @ Pf @ F.T + Q
gain = Pf @ jlag.solve(S, F, sym_pos=True).T
gain = Pf @ jlag.solve(S, F, assume_a="pos").T

inc_Sig = Pf - gain @ S @ gain.T
inc_m = mf - gain @ (F @ mf + b)
Expand Down
11 changes: 6 additions & 5 deletions parsmooth/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from jax.custom_derivatives import closure_convert
from jax.flatten_util import ravel_pytree
from jax.lax import while_loop
from jax.tree_util import tree_map


def cholesky_update_many(chol_init, update_vectors, multiplier):
Expand Down Expand Up @@ -83,17 +84,17 @@ def none_or_shift(x, shift):
if x is None:
return None
if shift > 0:
return jax.tree_map(lambda z: z[shift:], x)
return jax.tree_map(lambda z: z[:shift], x)
return tree_map(lambda z: z[shift:], x)
return tree_map(lambda z: z[:shift], x)


def none_or_concat(x, y, position=1):
if x is None or y is None:
return None
if position == 1:
return jax.tree_map(lambda a, b: jnp.concatenate([a[None, ...], b]), y, x)
return tree_map(lambda a, b: jnp.concatenate([a[None, ...], b]), y, x)
else:
return jax.tree_map(lambda a, b: jnp.concatenate([b, a[None, ...]]), y, x)
return tree_map(lambda a, b: jnp.concatenate([b, a[None, ...]]), y, x)


# FIXED POINT UTIL
Expand All @@ -120,7 +121,7 @@ def _fixed_point_rev(f, _criterion, res, x_star_bar):
(params, x_star, x_star_bar),
x_star_bar,
lambda i, *_: i < n_iter + 1)[0])
return theta_bar, jax.tree_map(jnp.zeros_like, x_star)
return theta_bar, tree_map(jnp.zeros_like, x_star)


def _rev_iter(f, u, *packed):
Expand Down
2 changes: 1 addition & 1 deletion parsmooth/linearization/_gh.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def _gauss_hermite_weights(n_dim: int, order: int = 3) -> Tuple[np.ndarray, np.n

w_1d = np.zeros(shape=(p,))
for i in range(p):
w_1d[i] = (2 ** (p - 1) * np.math.factorial(p) * np.sqrt(np.pi) /
w_1d[i] = (2 ** (p - 1) * math.factorial(p) * np.sqrt(np.pi) /
(p ** 2 * (np.polyval(hermite_coeff[p - 1],
hermite_roots[i])) ** 2))

Expand Down
11 changes: 6 additions & 5 deletions parsmooth/parallel/_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jlinalg
from jax.tree_util import tree_map

from parsmooth._base import MVNStandard, FunctionalModel, MVNSqrt, are_inputs_compatible, ConditionalMomentsModel
from parsmooth._utils import tria, none_or_concat, mvn_loglikelihood
Expand Down Expand Up @@ -63,8 +64,8 @@ def filtering(observations: jnp.ndarray,
def _standard_associative_params(linearization_method, transition_model, observation_model,
nominal_trajectory, x0, ys):
T = ys.shape[0]
n_k_1 = jax.tree_map(lambda z: z[:-1], nominal_trajectory)
n_k = jax.tree_map(lambda z: z[1:], nominal_trajectory)
n_k_1 = tree_map(lambda z: z[:-1], nominal_trajectory)
n_k = tree_map(lambda z: z[1:], nominal_trajectory)

m0, P0 = x0
ms = jnp.concatenate([m0[None, ...], jnp.zeros_like(m0, shape=(T - 1,) + m0.shape)])
Expand All @@ -82,7 +83,7 @@ def _standard_associative_params_one(linearization_method, transition_model, obs
P = F @ P @ F.T + Q

S = H @ P @ H.T + R
S_invH = jlinalg.solve(S, H, sym_pos=True)
S_invH = jlinalg.solve(S, H, assume_a="pos")
K = (S_invH @ P).T
A = F - K @ H @ F

Expand All @@ -99,8 +100,8 @@ def _standard_associative_params_one(linearization_method, transition_model, obs
def _sqrt_associative_params(linearization_method, transition_model, observation_model,
nominal_trajectory, x0, ys):
T = ys.shape[0]
n_k_1 = jax.tree_map(lambda z: z[:-1], nominal_trajectory)
n_k = jax.tree_map(lambda z: z[1:], nominal_trajectory)
n_k_1 = tree_map(lambda z: z[:-1], nominal_trajectory)
n_k = tree_map(lambda z: z[1:], nominal_trajectory)

m0, L0 = x0
ms = jnp.concatenate([m0[None, ...], jnp.zeros_like(m0, shape=(T - 1,) + m0.shape)])
Expand Down
4 changes: 2 additions & 2 deletions parsmooth/parallel/_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def standard_filtering_operator(elem1, elem2):
IpCJ = I_dim + jnp.dot(C1, J2)
IpJC = I_dim + jnp.dot(J2, C1)

AIpCJ_inv = jlinalg.solve(IpCJ.T, A2.T, sym_pos=False).T
AIpJC_inv = jlinalg.solve(IpJC.T, A1, sym_pos=False).T
AIpCJ_inv = jlinalg.solve(IpCJ.T, A2.T).T
AIpJC_inv = jlinalg.solve(IpJC.T, A1).T

A = jnp.dot(AIpCJ_inv, A1)
b = jnp.dot(AIpCJ_inv, b1 + jnp.dot(C1, eta2)) + b2
Expand Down
5 changes: 3 additions & 2 deletions parsmooth/parallel/_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jlinalg
from jax.tree_util import tree_map

from parsmooth._base import MVNStandard, FunctionalModel, MVNSqrt, are_inputs_compatible, ConditionalMomentsModel
from parsmooth._utils import none_or_concat, tria
Expand Down Expand Up @@ -46,7 +47,7 @@ def smoothing(transition_model: Union[FunctionalModel, ConditionalMomentsModel],
def _associative_params(linearization_method, transition_model,
nominal_trajectory, filtering_trajectory, sqrt):
ms, Ps = filtering_trajectory
nominal_trajectory = jax.tree_map(lambda z: z[:-1], nominal_trajectory)
nominal_trajectory = tree_map(lambda z: z[:-1], nominal_trajectory)
if sqrt:
vmapped_fn = jax.vmap(_sqrt_associative_params, in_axes=[None, None, 0, 0, 0])
else:
Expand All @@ -60,7 +61,7 @@ def _standard_associative_params(linearization_method, transition_model, n_k_1,
F, Q, b = linearization_method(transition_model, n_k_1)
Pp = F @ P @ F.T + Q

E = jlinalg.solve(Pp, F @ P, sym_pos=True).T
E = jlinalg.solve(Pp, F @ P, assume_a="pos").T

g = m - E @ (F @ m + b)
L = P - E @ Pp @ E.T
Expand Down
5 changes: 3 additions & 2 deletions parsmooth/sequential/_smoothing.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jlag
from jax.tree_util import tree_map

from parsmooth._base import MVNStandard, MVNSqrt, are_inputs_compatible, FunctionalModel, ConditionalMomentsModel
from parsmooth._utils import tria, none_or_shift, none_or_concat
Expand All @@ -12,7 +13,7 @@ def smoothing(transition_model: Union[FunctionalModel, ConditionalMomentsModel],
filter_trajectory: Union[MVNSqrt, MVNStandard],
linearization_method: Callable,
nominal_trajectory: Optional[Union[MVNSqrt, MVNStandard]] = None):
last_state = jax.tree_map(lambda z: z[-1], filter_trajectory)
last_state = tree_map(lambda z: z[-1], filter_trajectory)

if nominal_trajectory is not None:
are_inputs_compatible(filter_trajectory, nominal_trajectory)
Expand Down Expand Up @@ -49,7 +50,7 @@ def _standard_smooth(F, Q, b, xf, xs):
S = F @ Pf @ F.T + Q
cov_diff = Ps - S

gain = Pf @ jlag.solve(S, F, sym_pos=True).T
gain = Pf @ jlag.solve(S, F, assume_a="pos").T
ms = mf + gain @ mean_diff
Ps = Pf + gain @ cov_diff @ gain.T

Expand Down
2 changes: 2 additions & 0 deletions requirements-examples.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
matplotlib==3.9.0
pandas==2.2.2
2 changes: 1 addition & 1 deletion requirements-test.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@ pytest-cov
pytest-forked
pytest-html
pytest-xdist
tfp-nightly
tensorflow-probability
6 changes: 5 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
numpy
jax==0.4.28
jaxlib==0.4.28
numpy==1.26.4
scipy==1.13.0
tensorflow-probability==0.24.0
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def get_version(rel_path):
name="parsmooth",
author="Adrien Corenflos",
version=get_version("parsmooth/__init__.py"),
python_requires='>=3.8',
description="Parallel non-linear smoothing and parameter estimation for state space models",
long_description=long_description,
packages=setuptools.find_packages(),
Expand Down
4 changes: 1 addition & 3 deletions tests/test_parallel_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,8 @@ def test_vs_sequential_filter(dim_x, dim_y, seed, linearization_method):
np.testing.assert_array_almost_equal(sqrt_par_filter_res.mean, seq_filter_res.mean)
np.testing.assert_array_almost_equal(sqrt_par_filter_res.chol @ np.transpose(sqrt_par_filter_res.chol, [0, 2, 1]),
seq_filter_res.cov)
np.testing.assert_array_almost_equal(seq_filter_res.mean, seq_sqrt_filter_res.mean)
np.testing.assert_array_almost_equal(seq_filter_res.mean, seq_sqrt_filter_res.mean, decimal=5)

assert seq_sqrt_ell == pytest.approx(seq_ell)
assert par_ell == pytest.approx(seq_ell)
assert par_sqrt_ell == pytest.approx(seq_ell)


4 changes: 2 additions & 2 deletions tests/test_sequential_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from tests._lgssm import get_data, transition_function as lgssm_f, observation_function as lgssm_h
from tests._test_utils import get_system

LIST_LINEARIZATIONS = [cubature]
LIST_LINEARIZATIONS = [cubature, extended]


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -76,7 +76,7 @@ def test_update_value(dim_x, dim_y, seed, sqrt):

res = y - H @ x.mean - c
S = H @ x.cov @ H.T + R
K = x.cov @ solve(S, H, sym_pos=True).T
K = x.cov @ solve(S, H, assume_a="pos").T
np.testing.assert_allclose(next_x.mean, x.mean + K @ res, atol=1e-1)
np.testing.assert_allclose(cov, x.cov - K @ H @ x.cov, atol=1e-5)

Expand Down
2 changes: 1 addition & 1 deletion tests/test_sequential_smoother.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_smooth_one_value(dim_x, seed, sqrt):

m_ = F @ xf.mean + b
P_ = F @ xf.cov @ F.T + Q
G = xf.cov @ solve(P_.T, F, sym_pos=True).T
G = xf.cov @ solve(P_.T, F, assume_a="pos").T
ms = xf.mean + G @ (xs.mean - m_)
Ps = xf.cov + G @ (xs.cov - P_) @ G.T

Expand Down

0 comments on commit 36e9f1b

Please sign in to comment.