Skip to content

Commit

Permalink
Fix instability (#60)
Browse files Browse the repository at this point in the history
* Fix instability by replacing QR handmade by the now supported CUDA one.

* rerun the robustness experiment with 100 runs

---------

Co-authored-by: Adrien Corenflos <[email protected]>
  • Loading branch information
Fatemeh-Yaghoobi and AdrienCorenflos authored Jul 1, 2024
1 parent 36e9f1b commit cc7c6cf
Show file tree
Hide file tree
Showing 110 changed files with 235 additions and 94 deletions.
180 changes: 180 additions & 0 deletions notebooks/robustness_100runs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
import jax
import numpy as np
import matplotlib.pyplot as plt
from jax import jit
import jax.numpy as jnp

jax.config.update("jax_enable_x64", False)

from parsmooth._base import MVNStandard, FunctionalModel, MVNSqrt
from parsmooth.linearization import cubature, extended
from parsmooth.methods import iterated_smoothing
from bearing_data import get_data, make_parameters

s1 = jnp.array([-1.5, 0.5]) # First sensor location
s2 = jnp.array([1., 1.]) # Second sensor location
r = 0.5 # Observation noise (stddev)
x0 = jnp.array([0.1, 0.2, 1, 0]) # initial true location
dt = 0.01 # discretization time step
qc = 0.01 # discretization noise
qw = 0.1 # discretization noise

Q, R, observation_function, transition_function = make_parameters(qc, qw, r, dt, s1, s2)

chol_Q = jnp.linalg.cholesky(Q)
chol_R = jnp.linalg.cholesky(R)

m0 = jnp.array([-4., -1., 2., 7., 3.])
chol_P0 = jnp.eye(5)
P0 = jnp.eye(5)

init = MVNStandard(m0, P0)
chol_init = MVNSqrt(m0, chol_P0)


sqrt_transition_model = FunctionalModel(transition_function, MVNSqrt(jnp.zeros((5,)), chol_Q))
transition_model = FunctionalModel(transition_function, MVNStandard(jnp.zeros((5,)), Q))

sqrt_observation_model = FunctionalModel(observation_function, MVNSqrt(jnp.zeros((2,)), chol_R))
observation_model = FunctionalModel(observation_function, MVNStandard(jnp.zeros((2,)), R))

n_run = 100
# T = 8000
# for i in range(n_run):
# _, _, ys = get_data(x0, dt, r, T, s1, s2)
# jnp.savez(f"robustness_data/data-T{T}-Run{i + 1}.npz", data=ys)


Ts = [20, 30, 40, 50, 80, 100, 200, 300, 400, 500, 1000, 1500, 2000, 2500, 3000, 3500, 4000, 4500, 5000, 5500, 6000, 6500, 7000, 7500]

def func(method, Ts, runtime=n_run, n_iter=20, sqrt=True, mth="extended_std"):
ell_par = []
for i, T in enumerate(Ts):
print(f"Length {i + 1} out of {len(Ts)}")
ell_par_res = []
for j in range(runtime):
with np.load("robustness_data/data-T" + str(8000) + "-Run" + str(j + 1) + ".npz") as data:
ys = data["data"][0:T]

if sqrt:
initial_states_sqrt = MVNSqrt(jnp.repeat(jnp.array([[-1., -1., 6., 4., 2.]]), T + 1, axis=0),
jnp.repeat(jnp.eye(5).reshape(1, 5, 5), T + 1, axis=0))
args = ys, initial_states_sqrt, n_iter

else:
initial_states = MVNStandard(jnp.repeat(jnp.array([[-1., -1., 6., 4., 2.]]), T + 1, axis=0),
jnp.repeat(jnp.eye(5).reshape(1, 5, 5), T + 1, axis=0))
args = ys, initial_states, n_iter

_, ell = method(*args)

ell_par_res.append(ell)
print(f"run {j + 1} out of {runtime}", end="\r")

ell_par.append(ell_par_res)
print()

return ell_par
# Extended
def IEKS_std_par(observations, initial_points, iteration):
std_par_res, ell = iterated_smoothing(observations, init, transition_model, observation_model,
extended, initial_points, True,
criterion=lambda i, *_: i < iteration,
return_loglikelihood = True)
return std_par_res, ell


def IEKS_sqrt_par(observations, initial_points_sqrt, iteration):
sqrt_par_res, ell = iterated_smoothing(observations, chol_init, sqrt_transition_model, sqrt_observation_model,
extended, initial_points_sqrt, True,
criterion=lambda i, *_: i < iteration,
return_loglikelihood = True)
return sqrt_par_res, ell

gpu_IEKS_std_par = jit(IEKS_std_par, backend="gpu")
gpu_IEKS_sqrt_par = jit(IEKS_sqrt_par, backend="gpu")
## Extended Standard
# gpu_IEKS_std_par_ell = func(gpu_IEKS_std_par, Ts, sqrt=False)
# jnp.savez("robustness_data/ell_float32_extended_std_100runs", rts_gpu_IEKS_std_par_ell=gpu_IEKS_std_par_ell)

## Extended Sqrt
# gpu_IEKS_sqrt_par_ell = func(gpu_IEKS_sqrt_par, Ts, sqrt=True, mth="extended_sqrt")
# jnp.savez("robustness_data/ell_float32_extended_sqrt_100runs",
# rts_gpu_IEKS_sqrt_par_ell=gpu_IEKS_sqrt_par_ell)
# Cubature
def ICKS_std_par(observations, initial_points, iteration):
std_par_res, ell = iterated_smoothing(observations, init, transition_model, observation_model,
cubature, initial_points, True,
criterion=lambda i, *_: i < iteration,
return_loglikelihood = True)
return std_par_res, ell


def ICKS_sqrt_par(observations, initial_points_sqrt, iteration):
sqrt_par_res, ell = iterated_smoothing(observations, chol_init, sqrt_transition_model, sqrt_observation_model,
cubature, initial_points_sqrt, True,
criterion=lambda i, *_: i < iteration,
return_loglikelihood = True)
return sqrt_par_res, ell

gpu_ICKS_std_par = jit(ICKS_std_par, backend="gpu")
gpu_ICKS_sqrt_par = jit(ICKS_sqrt_par, backend="gpu")

##Cubature Standard
# gpu_ICKS_std_par_ell = func(gpu_ICKS_std_par, Ts, sqrt=False, mth="cubature_std")
# jnp.savez("robustness_data/ell_float32_cubature_std_100runs", gpu_ICKS_std_par_ell=gpu_ICKS_std_par_ell)

## Cubature Sqrt
# gpu_ICKS_sqrt_par_ell = func(gpu_ICKS_sqrt_par, Ts, sqrt=True, mth="cubature_sqrt")
# jnp.savez("robustness_data/ell_float32_cubature_sqrt_100runs", gpu_ICKS_sqrt_par_ell=gpu_ICKS_sqrt_par_ell)


with np.load("robustness_data/ell_float32_extended_std_100runs.npz") as data:
gpu_IEKS_std_par_ell = data["rts_gpu_IEKS_std_par_ell"]
with np.load("robustness_data/ell_float32_extended_sqrt_100runs.npz") as data:
gpu_IEKS_sqrt_par_ell = data["rts_gpu_IEKS_sqrt_par_ell"]

plt.figure()
plt.plot(Ts, np.mean(np.isnan(gpu_IEKS_std_par_ell), axis=1)*100, '*--', label="Standard IEKS")
plt.plot(Ts, np.mean(np.isnan(gpu_IEKS_sqrt_par_ell), axis=1)*100, '*--', label="Sqrt IEKS")
plt.title("Percentage of NaNs: 100 runs")
plt.legend()
plt.show()

with np.load("robustness_data/ell_float32_cubature_sqrt_100runs.npz") as data:
gpu_ICKS_sqrt_par_ell = data["gpu_ICKS_sqrt_par_ell"]
with np.load("robustness_data/ell_float32_cubature_std_100runs.npz") as data:
gpu_ICKS_std_par_ell = data["gpu_ICKS_std_par_ell"]
plt.figure()
plt.plot(Ts, np.mean(np.isnan(gpu_ICKS_std_par_ell), axis=1)*100, '*--', label="Standard ICKS")
plt.plot(Ts, np.mean(np.isnan(gpu_ICKS_sqrt_par_ell), axis=1)*100, '*--', label="Sqrt ICKS")
plt.grid()
plt.legend()
plt.title("Percentage of NaNs: 100 runs")
plt.show()

## Make CSV files
import pandas as pd
data_rts_cubature = np.stack([Ts,
np.mean(np.isnan(gpu_ICKS_sqrt_par_ell), axis=1)*100,
np.mean(np.isnan(gpu_ICKS_std_par_ell), axis=1)*100],
axis=1)

columns = ["observations",
"fl32_gpu_cubature_sqrt_par_ell",
"fl32_gpu_cubature_std_par_ell"]

df1 = pd.DataFrame(data=data_rts_cubature, columns=columns)
# df1.to_csv("robustness_data/R1_rts_fl32_cubature_ell.csv")

data_rts_extended = np.stack([Ts,
np.mean(np.isnan(gpu_IEKS_sqrt_par_ell), axis=1)*100,
np.mean(np.isnan(gpu_IEKS_std_par_ell), axis=1)*100],
axis=1)

columns = ["observations",
"fl32_gpu_extended_sqrt_par_ell",
"fl32_gpu_extended_std_par_ell"]

df2 = pd.DataFrame(data=data_rts_extended, columns=columns)
# df2.to_csv("robustness_data/R1_rts_fl32_extended_ell.csv")
25 changes: 25 additions & 0 deletions notebooks/robustness_data/R1_rts_fl32_cubature_ell.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
,observations,fl32_gpu_cubature_sqrt_par_ell,fl32_gpu_cubature_std_par_ell
0,20.0,0.0,0.0
1,30.0,0.0,0.0
2,40.0,0.0,2.0
3,50.0,0.0,4.0
4,80.0,0.0,19.0
5,100.0,0.0,100.0
6,200.0,0.0,100.0
7,300.0,0.0,100.0
8,400.0,0.0,100.0
9,500.0,0.0,100.0
10,1000.0,0.0,100.0
11,1500.0,0.0,100.0
12,2000.0,0.0,100.0
13,2500.0,0.0,100.0
14,3000.0,0.0,100.0
15,3500.0,0.0,100.0
16,4000.0,0.0,100.0
17,4500.0,1.0,100.0
18,5000.0,0.0,100.0
19,5500.0,0.0,100.0
20,6000.0,0.0,100.0
21,6500.0,1.0,100.0
22,7000.0,2.0,100.0
23,7500.0,2.0,100.0
25 changes: 25 additions & 0 deletions notebooks/robustness_data/R1_rts_fl32_extended_ell.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
,observations,fl32_gpu_extended_sqrt_par_ell,fl32_gpu_extended_std_par_ell
0,20.0,0.0,0.0
1,30.0,0.0,0.0
2,40.0,0.0,0.0
3,50.0,0.0,0.0
4,80.0,0.0,0.0
5,100.0,0.0,0.0
6,200.0,0.0,0.0
7,300.0,0.0,0.0
8,400.0,0.0,0.0
9,500.0,0.0,0.0
10,1000.0,0.0,0.0
11,1500.0,0.0,0.0
12,2000.0,0.0,0.0
13,2500.0,0.0,0.0
14,3000.0,0.0,0.0
15,3500.0,0.0,1.0
16,4000.0,0.0,0.0
17,4500.0,0.0,1.0
18,5000.0,0.0,1.0
19,5500.0,0.0,1.0
20,6000.0,0.0,1.0
21,6500.0,0.0,1.0
22,7000.0,0.0,2.0
23,7500.0,0.0,2.0
Binary file added notebooks/robustness_data/data-T8000-Run1.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run10.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run100.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run11.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run12.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run13.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run14.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run15.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run16.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run17.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run18.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run19.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run2.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run20.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run21.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run22.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run23.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run24.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run25.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run26.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run27.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run28.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run29.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run3.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run30.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run31.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run32.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run33.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run34.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run35.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run36.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run37.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run38.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run39.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run4.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run40.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run41.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run42.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run43.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run44.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run45.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run46.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run47.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run48.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run49.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run5.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run50.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run51.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run52.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run53.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run54.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run55.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run56.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run57.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run58.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run59.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run6.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run60.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run61.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run62.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run63.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run64.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run65.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run66.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run67.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run68.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run69.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run7.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run70.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run71.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run72.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run73.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run74.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run75.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run76.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run77.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run78.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run79.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run8.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run80.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run81.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run82.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run83.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run84.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run85.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run86.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run87.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run88.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run89.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run9.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run90.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run91.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run92.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run93.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run94.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run95.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run96.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run97.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run98.npz
Binary file not shown.
Binary file added notebooks/robustness_data/data-T8000-Run99.npz
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
76 changes: 2 additions & 74 deletions parsmooth/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ def body(chol, update_vector):


def tria(A):
return qr(A.T).T
_, R = jlinalg.qr(A.T, mode='economic')
return R.T


def _set_diagonal(x, y):
Expand Down Expand Up @@ -157,76 +158,3 @@ def mvn_loglikelihood(x, chol_cov):
)
norm_y = jnp.sum(y * y, -1)
return -0.5 * norm_y - normalizing_constant


@jax.custom_jvp
def qr(A: jnp.ndarray):
"""The JAX provided implementation is not parallelizable using VMAP. As a consequence, we have to rewrite it..."""
return _qr(A)


# @partial(jax.jit, static_argnums=(1,))
def _qr(A: jnp.ndarray, return_q=False):
m, n = A.shape
min_ = min(m, n)
if return_q:
Q = jnp.eye(m)

for j in range(min_):
# Apply Householder transformation.
v, tau = _householder(A[j:, j])

H = jnp.eye(m)
H = H.at[j:, j:].add(-tau * (v[:, None] @ v[None, :]))

A = H @ A
if return_q:
Q = H @ Q # noqa

R = jnp.triu(A[:min_, :min_])
if return_q:
return Q[:n].T, R # noqa
else:
return R


def _householder(a):
if a.dtype == jnp.float64:
eps = 1e-9
else:
eps = 1e-7

alpha = a[0]
s = jnp.sum(a[1:] ** 2)
cond = s < eps

def if_not_cond(v):
t = (alpha ** 2 + s) ** 0.5
v0 = jax.lax.cond(alpha <= 0, lambda _: alpha - t, lambda _: -s / (alpha + t), None)
tau = 2 * v0 ** 2 / (s + v0 ** 2)
v = v / v0
v = v.at[0].set(1.)
return v, tau

return jax.lax.cond(cond, lambda v: (v, 0.), if_not_cond, a)


def qr_jvp_rule(primals, tangents):
x, = primals
dx, = tangents
q, r = _qr(x, True)
m, n = x.shape
min_ = min(m, n)
if m < n:
dx = dx[:, :m]
dx_rinv = jax.lax.linalg.triangular_solve(r, dx)
qt_dx_rinv = jnp.matmul(q.T, dx_rinv)
qt_dx_rinv_lower = jnp.tril(qt_dx_rinv, -1)
do = qt_dx_rinv_lower - qt_dx_rinv_lower.T # This is skew-symmetric
# The following correction is necessary for complex inputs
do = do + jnp.eye(min_, dtype=do.dtype) * (qt_dx_rinv - jnp.real(qt_dx_rinv))
dr = jnp.matmul(qt_dx_rinv - do, r)
return r, dr


qr.defjvp(qr_jvp_rule)
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
jax==0.4.28
jaxlib==0.4.28
jax==0.4.30
jaxlib==0.4.30
numpy==1.26.4
scipy==1.13.0
tensorflow-probability==0.24.0
19 changes: 1 addition & 18 deletions tests/test_math_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import jax
import jax.numpy as jnp
import numpy as np
import pytest
import tensorflow_probability.substrates.jax as tfp
from jax.test_util import check_grads

from parsmooth._utils import _cholesky_update, cholesky_update_many, fixed_point, _qr, qr
from parsmooth._utils import _cholesky_update, cholesky_update_many, fixed_point


@pytest.fixture(scope="session", autouse=True)
Expand Down Expand Up @@ -55,22 +54,6 @@ def test_cholesky_update_many(multiplier, seed, dim_x):
np.testing.assert_allclose(cholRes @ cholRes.T, expected, rtol=1e-4)


@pytest.mark.parametrize("seed", [0, 1, 2, 3])
def test_qr(seed):
np.random.seed(seed)
A = np.random.randn(3, 2)
B = np.random.randn(2, 3)

q, r = _qr(jnp.array(A), True)
np.testing.assert_allclose(q @ r, A)

q, r = _qr(jnp.array(B), True)
np.testing.assert_allclose((q @ r), B[:, :2])

check_grads(qr, (A,), 1, modes=["rev", "fwd"])
check_grads(qr, (B,), 1, modes=["rev", "fwd"])


def test_fixed_point():
def my_fun(a, b, x0):
f = lambda x: (a * x[0] + b[0],)
Expand Down

0 comments on commit cc7c6cf

Please sign in to comment.