Skip to content

Commit 484bc3b

Browse files
authored
Merge pull request #15 from kirkegaardlab/fix_jax_deprecation
fix: replace deprecated jax.numpy.trapz
2 parents 80c1c06 + 16a6c38 commit 484bc3b

File tree

5 files changed

+23
-9
lines changed

5 files changed

+23
-9
lines changed

.github/workflows/integration_test.yml

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ jobs:
1414
steps:
1515
- uses: actions/checkout@v2
1616

17-
- name: Set up Python 3.8
17+
- name: Set up Python 3.10
1818
uses: actions/setup-python@v2
1919
with:
20-
python-version: '3.8'
20+
python-version: '3.10'
2121
architecture: 'x64'
2222

2323
- name: apt-get

celegans/simulation.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import jax
88
import jax.numpy as jnp
9+
from jax.scipy.integrate import trapezoid
910

1011

1112
def _theta(t, s, params):
@@ -123,9 +124,9 @@ def solve(t, u, X, ds, alpha):
123124
fx = Ut * tx[jnp.newaxis] + alpha * Un * nx[jnp.newaxis]
124125
fy = Ut * ty[jnp.newaxis] + alpha * Un * ny[jnp.newaxis]
125126

126-
Fx = jnp.trapz(fx, dx=ds)
127-
Fy = jnp.trapz(fy, dx=ds)
128-
Tau = jnp.trapz(x * fy - y * fx, dx=ds)
127+
Fx = trapezoid(fx, dx=ds)
128+
Fy = trapezoid(fy, dx=ds)
129+
Tau = trapezoid(x * fy - y * fx, dx=ds)
129130

130131
b = -jnp.array([Fx[0], Fy[0], Tau[0]])
131132
A = jnp.array([Fx[1:], Fy[1:], Tau[1:]])

examples/detect.py

+6
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,12 @@
22
import deeptangle as dt
33
import matplotlib.pyplot as plt
44
from skimage.exposure import equalize_adapthist
5+
import numpy
6+
7+
# scikit-video uses deprecated numpy.float, numpy.int
8+
# hacky fix: https://github.com/scikit-video/scikit-video/issues/154
9+
numpy.float = numpy.float64
10+
numpy.int = numpy.int_
511
import skvideo.io
612

713

examples/track.py

+6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@
99
import matplotlib.pyplot as plt
1010
import numpy as np
1111
from skimage.exposure import equalize_adapthist
12+
13+
# scikit-video uses deprecated numpy.float, numpy.int
14+
# hacky fix: https://github.com/scikit-video/scikit-video/issues/154
15+
import numpy
16+
numpy.float = numpy.float64
17+
numpy.int = numpy.int_
1218
import skvideo.io
1319

1420
import deeptangle as dt

requirements.txt

+5-4
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@ scikit-image
33
scikit-video
44
optax
55
chex
6-
jax
6+
jax>=0.4.16
7+
jaxlib>=0.4.20
78
dm-pix
89
scikit-learn
9-
numpy==1.21.6
10-
numba==0.55
10+
numpy>=1.21.6
11+
numba>=0.55
1112
matplotlib
12-
https://github.com/alonfnt/dm-haiku/archive/refs/heads/avg_pool_perf.zip
13+
dm-haiku
1314
trackpy

0 commit comments

Comments
 (0)