
🧪 Test examples in the documentation
francois-rozet committed Jun 19, 2024
1 parent 7c05693 commit 0852970
Showing 16 changed files with 93 additions and 45 deletions.
14 changes: 14 additions & 0 deletions .github/workflows/ci.yml
@@ -34,3 +34,17 @@ jobs:
           pip install .
       - name: Run tests
         run: pytest tests
+  doctest:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with:
+          python-version: "3.10"
+      - name: Install dependencies
+        run: |
+          pip install pytest
+          pip install torch==2.0 --extra-index-url https://download.pytorch.org/whl/cpu
+          pip install .
+      - name: Run doctests
+        run: pytest zuko --doctest-modules
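(The job pins `torch==2.0` from the CPU wheel index, presumably so that the hard-coded tensor values in the doctests below stay reproducible across CI runs.)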
6 changes: 6 additions & 0 deletions CONTRIBUTING.md
@@ -40,6 +40,12 @@ We use [pytest](https://docs.pytest.org) to test our code base. If your contribu
 pytest tests
 ```
 
+Additionally, examples in the documentation are tested with
+
+```
+pytest zuko --doctest-modules
+```
+
 When you submit a pull request, tests are automatically (upon approval) executed for several versions of Python and PyTorch.
 
 ### Code conventions
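For context, `--doctest-modules` makes pytest collect the docstrings of every module in the package and execute their `>>>` lines, comparing what is printed against the expected output underneath. A minimal sketch of the pattern, with a hypothetical `square` function that is not part of zuko:

```
def square(x: float) -> float:
    r"""Returns the square of x.

    Example:
        >>> square(3.0)
        9.0
    """
    return x * x
```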
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -63,6 +63,9 @@ exclude = ["*.ipynb"]
 preview = true
 quote-style = "preserve"
 
+[tool.setuptools]
+packages = ["zuko"]
+
 [tool.setuptools.dynamic]
 dependencies = {file = "requirements.txt"}
 version = {attr = "zuko.__version__"}
12 changes: 12 additions & 0 deletions tests/conftest.py
@@ -0,0 +1,12 @@
+r"""Tests configuration."""
+
+import pytest
+import torch
+
+
+@pytest.fixture(autouse=True, scope='module')
+def torch_float64():
+    try:
+        yield torch.set_default_dtype(torch.float64)
+    finally:
+        torch.set_default_dtype(torch.float32)
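This module-scoped, autouse fixture replaces the per-file `torch.set_default_dtype(torch.float64)` calls removed from the test files below: every test module now runs in double precision, and single precision is restored once the module finishes. A rough sketch of the effect, assuming CPU defaults:

```
import torch

torch.set_default_dtype(torch.float64)
print(torch.randn(3).dtype)  # torch.float64

torch.set_default_dtype(torch.float32)  # restored after the module's tests
print(torch.randn(3).dtype)  # torch.float32
```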
2 changes: 0 additions & 2 deletions tests/test_flows.py
@@ -8,8 +8,6 @@
 from torch import randn
 from zuko.flows import *
 
-torch.set_default_dtype(torch.float64)
-
 
 @pytest.mark.parametrize('F', [GMM, NICE, MAF, NSF, SOSPF, NAF, UNAF, CNF, GF, BPF])
 def test_flows(tmp_path: Path, F: callable):
2 changes: 0 additions & 2 deletions tests/test_transforms.py
@@ -7,8 +7,6 @@
 from torch.distributions import *
 from zuko.transforms import *
 
-torch.set_default_dtype(torch.float64)
-
 
 @pytest.mark.parametrize('batched', [False, True])
 def test_univariate_transforms(batched: bool):
2 changes: 0 additions & 2 deletions tests/test_utils.py
@@ -6,8 +6,6 @@
 from torch import randn
 from zuko.utils import *
 
-torch.set_default_dtype(torch.float64)
-
 
 def test_bisection():
     alpha = torch.tensor(1.0, requires_grad=True)
17 changes: 17 additions & 0 deletions zuko/conftest.py
@@ -0,0 +1,17 @@
+r"""Doctests configuration."""
+
+import pytest
+import torch
+import zuko
+
+
+@pytest.fixture(autouse=True, scope='module')
+def doctest_imports(doctest_namespace):
+    doctest_namespace['torch'] = torch
+    doctest_namespace['zuko'] = zuko
+
+
+@pytest.fixture(autouse=True)
+def torch_seed():
+    with torch.random.fork_rng():
+        yield torch.random.manual_seed(0)
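The `torch_seed` fixture reseeds the default generator before each doctest, which is what makes the expected outputs in the hunks below deterministic; `fork_rng()` keeps the seeding from leaking into other tests. A quick check of the first draw, assuming PyTorch's CPU generator:

```
import torch

with torch.random.fork_rng():  # isolate the seed change
    torch.random.manual_seed(0)
    print(torch.randn(3))  # tensor([ 1.5410, -0.2934, -2.1788])
```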
24 changes: 12 additions & 12 deletions zuko/distributions.py
@@ -65,7 +65,7 @@ class NormalizingFlow(Distribution):
     Example:
         >>> d = NormalizingFlow(ExpTransform(), Gamma(2.0, 1.0))
         >>> d.sample()
-        tensor(1.1316)
+        tensor(1.5157)
     """
 
     has_rsample = True
@@ -153,7 +153,7 @@ class Joint(Distribution):
         >>> d.event_shape
         torch.Size([2])
         >>> d.sample()
-        tensor([ 0.8969, -2.6717])
+        tensor([0.4963, 0.2072])
     """
 
     has_rsample = True
@@ -231,7 +231,7 @@ class Mixture(Distribution):
         >>> d.event_shape
         torch.Size([])
         >>> d.sample()
-        tensor(2.8732)
+        tensor(-1.6920)
     """
 
     has_rsample = False
@@ -301,7 +301,7 @@ class GeneralizedNormal(Distribution):
     Example:
         >>> d = GeneralizedNormal(2.0)
         >>> d.sample()
-        tensor(0.7480)
+        tensor(-0.0281)
     """
 
     arg_constraints = {'beta': constraints.positive}
@@ -350,7 +350,7 @@ class DiagNormal(Independent):
         >>> d.event_shape
         torch.Size([3])
         >>> d.sample()
-        tensor([ 0.7304, -0.1976, -1.7591])
+        tensor([ 1.5410, -0.2934, -2.1788])
     """
 
     def __init__(self, loc: Tensor, scale: Tensor, ndims: int = 1):
@@ -383,7 +383,7 @@ class BoxUniform(Independent):
         >>> d.event_shape
         torch.Size([3])
         >>> d.sample()
-        tensor([ 0.1859, -0.9698,  0.0665])
+        tensor([-0.0075,  0.5364, -0.8230])
     """
 
     def __init__(self, lower: Tensor, upper: Tensor, ndims: int = 1):
@@ -416,7 +416,7 @@ class TransformedUniform(NormalizingFlow):
     Example:
         >>> d = TransformedUniform(ExpTransform(), -1.0, 1.0)
         >>> d.sample()
-        tensor(0.5594)
+        tensor(0.4281)
     """
 
     def __init__(self, f: Transform, lower: Tensor, upper: Tensor):
@@ -445,7 +445,7 @@ class Truncated(Distribution):
     Example:
         >>> d = Truncated(Normal(0.0, 1.0), 1.0, 2.0)
         >>> d.sample()
-        tensor(1.2573)
+        tensor(1.3333)
     """
 
     has_rsample = True
@@ -509,7 +509,7 @@ class Sort(Distribution):
         >>> d.event_shape
         torch.Size([3])
         >>> d.sample()
-        tensor([-1.4434, -0.3861,  0.2439])
+        tensor([-2.1788, -0.2934,  1.5410])
     """
 
     def __init__(
@@ -590,7 +590,7 @@ class TopK(Sort):
         >>> d.event_shape
         torch.Size([2])
         >>> d.sample()
-        tensor([-0.2167,  0.6739])
+        tensor([-2.1788, -0.2934])
     """
 
     def __init__(
@@ -646,7 +646,7 @@ class Minimum(TopK):
         >>> d.event_shape
         torch.Size([])
         >>> d.sample()
-        tensor(-1.7552)
+        tensor(-2.1788)
     """
 
     def __init__(self, base: Distribution, n: int = 2):
@@ -687,7 +687,7 @@ class Maximum(Minimum):
         >>> d.event_shape
         torch.Size([])
         >>> d.sample()
-        tensor(1.1644)
+        tensor(1.5410)
     """
 
     def __init__(self, base: Distribution, n: int = 2):
8 changes: 4 additions & 4 deletions zuko/flows/autoregressive.py
@@ -59,11 +59,11 @@ class MaskedAutoregressiveTransform(LazyTransform):
         )
         >>> x = torch.randn(3)
         >>> x
-        tensor([-0.9485,  1.5290,  0.2018])
+        tensor([ 1.7428, -1.6483, -0.9920])
         >>> c = torch.randn(4)
         >>> y = t(c)(x)
         >>> t(c).inv(y)
-        tensor([-0.9485,  1.5290,  0.2018])
+        tensor([ 1.7428, -1.6483, -0.9920], grad_fn=<DivBackward0>)
     """
 
     def __new__(
@@ -205,9 +205,9 @@ class MAF(Flow):
         >>> c = torch.randn(4)
         >>> x = flow(c).sample()
         >>> x
-        tensor([-1.7154, -0.4401,  0.7505])
+        tensor([-0.5005, -1.6303,  0.3805])
         >>> flow(c).log_prob(x)
-        tensor(-4.4630, grad_fn=<AddBackward0>)
+        tensor(-3.7514, grad_fn=<AddBackward0>)
     """
 
     def __init__(
5 changes: 3 additions & 2 deletions zuko/flows/continuous.py
@@ -55,11 +55,12 @@ class FFJTransform(LazyTransform):
         )
         >>> x = torch.randn(3)
         >>> x
-        tensor([ 0.1777,  1.0139, -1.0370])
+        tensor([ 0.6365, -0.3181,  1.1519])
         >>> c = torch.randn(4)
         >>> y = t(c)(x)
         >>> t(c).inv(y)
-        tensor([ 0.1777,  1.0139, -1.0370])
+        tensor([ 0.6364, -0.3181,  1.1519],
+               grad_fn=<AdaptiveCheckpointAdjointBackward>)
     """
 
     def __init__(
9 changes: 5 additions & 4 deletions zuko/flows/core.py
@@ -181,19 +181,20 @@ class Unconditional(nn.Module):
         kwargs: The keyword arguments passed to `meta`.
 
     Examples:
+        >>> f = zuko.distributions.DiagNormal
         >>> mu, sigma = torch.zeros(3), torch.ones(3)
-        >>> d = Unconditional(DiagNormal, mu, sigma, buffer=True)
+        >>> d = Unconditional(f, mu, sigma, buffer=True)
         >>> d()
         DiagNormal(loc: torch.Size([3]), scale: torch.Size([3]))
         >>> d().sample()
-        tensor([-0.6687, -0.9690,  1.7461])
+        tensor([ 1.5410, -0.2934, -2.1788])
 
-        >>> t = Unconditional(ExpTransform)
+        >>> t = Unconditional(zuko.transforms.ExpTransform)
         >>> t()
         ExpTransform()
         >>> x = torch.randn(3)
         >>> t()(x)
-        tensor([0.5523, 0.7997, 0.9189])
+        tensor([1.7655, 0.3381, 0.2469])
     """
 
     def __init__(
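Note that the example above now spells out `zuko.distributions.DiagNormal` and `zuko.transforms.ExpTransform`: the doctest namespace defined in `zuko/conftest.py` only injects `torch` and `zuko`, so the bare names would presumably not resolve when the docstring is executed under `--doctest-modules`.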
4 changes: 2 additions & 2 deletions zuko/flows/coupling.py
@@ -56,11 +56,11 @@ class GeneralCouplingTransform(LazyTransform):
         )
         >>> x = torch.randn(3)
         >>> x
-        tensor([-0.8743,  0.6232,  1.2439])
+        tensor([-0.7900, -0.3259, -1.3184])
         >>> c = torch.randn(4)
         >>> y = t(c)(x)
         >>> t(c).inv(y)
-        tensor([-0.8743,  0.6232,  1.2439])
+        tensor([-0.7900, -0.3259, -1.3184], grad_fn=<IndexPutBackward0>)
     """
 
     def __new__(
4 changes: 2 additions & 2 deletions zuko/flows/gaussianization.py
@@ -51,11 +51,11 @@ class ElementWiseTransform(LazyTransform):
         )
         >>> x = torch.randn(3)
         >>> x
-        tensor([ 2.1983, -1.3182,  0.0329])
+        tensor([ 0.0303,  0.3644, -1.1831])
         >>> c = torch.randn(4)
         >>> y = t(c)(x)
         >>> t(c).inv(y)
-        tensor([ 2.1983, -1.3182,  0.0329])
+        tensor([ 0.0303,  0.3644, -1.1831], grad_fn=<DivBackward0>)
     """
 
     def __init__(
24 changes: 12 additions & 12 deletions zuko/nn.py
@@ -233,10 +233,10 @@ class MaskedMLP(nn.Sequential):
     Example:
         >>> adjacency = torch.randn(4, 3) < 0
        >>> adjacency
-        tensor([[False,  True, False],
-                [ True, False,  True],
-                [False,  True, False],
-                [False,  True,  True]])
+        tensor([[False,  True,  True],
+                [False,  True,  True],
+                [False, False,  True],
+                [ True,  True, False]])
         >>> net = MaskedMLP(adjacency, [16, 32], activation=nn.ELU)
         >>> net
         MaskedMLP(
@@ -248,10 +248,10 @@ class MaskedMLP(nn.Sequential):
         )
         >>> x = torch.randn(3)
         >>> torch.autograd.functional.jacobian(net, x)
-        tensor([[ 0.0000,  0.0031,  0.0000],
-                [-0.0323,  0.0000, -0.0547],
-                [ 0.0000, -0.0245,  0.0000],
-                [ 0.0000,  0.0060, -0.0063]])
+        tensor([[ 0.0000, -0.0065,  0.1158],
+                [ 0.0000, -0.0089,  0.0072],
+                [ 0.0000,  0.0000,  0.0089],
+                [-0.0146, -0.0128,  0.0000]])
     """
 
     def __init__(
@@ -374,10 +374,10 @@ class MonotonicMLP(MLP):
         )
         >>> x = torch.randn(3)
         >>> torch.autograd.functional.jacobian(net, x)
-        tensor([[0.8742, 0.9439, 0.9759],
-                [0.8969, 0.9716, 0.9866],
-                [1.0780, 1.1651, 1.2056],
-                [0.8596, 0.9400, 0.9502]])
+        tensor([[1.0492, 1.3094, 1.1711],
+                [1.1201, 1.3825, 1.2711],
+                [0.9397, 1.1915, 1.0787],
+                [1.1049, 1.3635, 1.2592]])
     """
 
     def __init__(self, *args, **kwargs):
2 changes: 1 addition & 1 deletion zuko/utils.py
@@ -302,7 +302,7 @@ def odeint(
         >>> x0 = torch.randn(3)
         >>> x1 = odeint(f, x0, 0.0, 1.0)
         >>> x1
-        tensor([-3.7454, -0.4140,  0.2677])
+        tensor([-1.4596,  0.5008,  1.5828])
     """
 
     settings = (atol, rtol, torch.is_grad_enabled())
