Skip to content

Commit

Permalink
✨ Add UnconditionalDistribution and UnconditionalTransform
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Jun 19, 2024
1 parent b3a1099 commit 110757f
Show file tree
Hide file tree
Showing 12 changed files with 243 additions and 43 deletions.
7 changes: 4 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,17 +55,18 @@ x = flow(c_star).sample((64,))
Alternatively, flows can be built as custom `Flow` objects.

```python
from zuko.flows import Flow, MaskedAutoregressiveTransform, Unconditional
from zuko.flows import Flow, UnconditionalDistribution, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal
from zuko.transforms import RotationTransform

flow = Flow(
transform=[
MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
Unconditional(RotationTransform, torch.randn(3, 3)),
UnconditionalTransform(RotationTransform, torch.randn(3, 3)),
MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
],
base=Unconditional(
base=UnconditionalDistribution(
DiagNormal,
torch.zeros(3),
torch.ones(3),
Expand Down
7 changes: 4 additions & 3 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,17 +58,18 @@ Alternatively, flows can be built as custom :class:`zuko.flows.core.Flow` object

.. code-block:: python
from zuko.flows import Flow, MaskedAutoregressiveTransform, Unconditional
from zuko.flows import Flow, UnconditionalDistribution, UnconditionalTransform
from zuko.flows.autoregressive import MaskedAutoregressiveTransform
from zuko.distributions import DiagNormal
from zuko.transforms import RotationTransform
flow = Flow(
transform=[
MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
Unconditional(RotationTransform, torch.randn(3, 3)),
UnconditionalTransform(RotationTransform, torch.randn(3, 3)),
MaskedAutoregressiveTransform(3, 5, hidden_features=(64, 64)),
],
base=Unconditional(
base=UnconditionalDistribution(
DiagNormal,
torch.zeros(3),
torch.ones(3),
Expand Down
9 changes: 8 additions & 1 deletion zuko/flows/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,14 @@

from .autoregressive import MAF, MaskedAutoregressiveTransform
from .continuous import CNF, FFJTransform
from .core import Flow, LazyDistribution, LazyInverse, LazyTransform, Unconditional
from .core import (
Flow,
LazyDistribution,
LazyInverse,
LazyTransform,
UnconditionalDistribution,
UnconditionalTransform,
)
from .coupling import NICE, GeneralCouplingTransform
from .gaussianization import GF, ElementWiseTransform
from .mixture import GMM
Expand Down
6 changes: 3 additions & 3 deletions zuko/flows/autoregressive.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Callable, Sequence

# isort: split
from .core import Flow, LazyTransform, Unconditional
from .core import Flow, LazyTransform, UnconditionalDistribution
from .gaussianization import ElementWiseTransform
from ..distributions import DiagNormal
from ..nn import MaskedMLP
Expand Down Expand Up @@ -200,7 +200,7 @@ class MAF(Flow):
)
)
)
(base): Unconditional(DiagNormal(loc: torch.Size([3]), scale: torch.Size([3])))
(base): UnconditionalDistribution(DiagNormal(loc: torch.Size([3]), scale: torch.Size([3])))
)
>>> c = torch.randn(4)
>>> x = flow(c).sample()
Expand Down Expand Up @@ -233,7 +233,7 @@ def __init__(
for i in range(transforms)
]

base = Unconditional(
base = UnconditionalDistribution(
DiagNormal,
torch.zeros(features),
torch.ones(features),
Expand Down
4 changes: 2 additions & 2 deletions zuko/flows/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from torch.distributions import Transform

# isort: split
from .core import Flow, LazyTransform, Unconditional
from .core import Flow, LazyTransform, UnconditionalDistribution
from ..distributions import DiagNormal
from ..nn import MLP
from ..transforms import FreeFormJacobianTransform
Expand Down Expand Up @@ -138,7 +138,7 @@ def __init__(
**kwargs,
)

base = Unconditional(
base = UnconditionalDistribution(
DiagNormal,
torch.zeros(features),
torch.ones(features),
Expand Down
129 changes: 113 additions & 16 deletions zuko/flows/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,13 @@
'LazyDistribution',
'LazyTransform',
'Unconditional',
'UnconditionalDistribution',
'UnconditionalTransform',
]

import abc
import torch.nn as nn
import warnings

from torch import Tensor
from torch.distributions import Distribution, Transform
Expand All @@ -20,6 +23,7 @@
# isort: split
from ..distributions import NormalizingFlow
from ..transforms import ComposedTransform
from ..utils import Partial


class LazyDistribution(nn.Module, abc.ABC):
Expand Down Expand Up @@ -174,27 +178,15 @@ class Unconditional(nn.Module):
Typically, the constructor returns a distribution or transformation. The positional
arguments of the constructor are registered as buffers or parameters.
Warning:
:class:`Unconditional` is deprecated and will be removed in the future. Use
:class:`UnconditionalDistribution` or :class:`UnconditionalTransform` instead.
Arguments:
meta: An arbitrary constructor function.
args: The positional tensor arguments passed to `meta`.
buffer: Whether tensors are registered as buffers or parameters.
kwargs: The keyword arguments passed to `meta`.
Examples:
>>> f = zuko.distributions.DiagNormal
>>> mu, sigma = torch.zeros(3), torch.ones(3)
>>> d = Unconditional(f, mu, sigma, buffer=True)
>>> d()
DiagNormal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> d().sample()
tensor([ 1.5410, -0.2934, -2.1788])
>>> t = Unconditional(zuko.transforms.ExpTransform)
>>> t()
ExpTransform()
>>> x = torch.randn(3)
>>> t()(x)
tensor([1.7655, 0.3381, 0.2469])
"""

def __init__(
Expand All @@ -206,6 +198,15 @@ def __init__(
):
super().__init__()

warnings.warn(
(
"'Unconditional' is deprecated and will be removed in the future. "
"Use 'UnconditionalDistribution' or 'UnconditionalTransform' instead."
),
category=DeprecationWarning,
stacklevel=2,
)

self.meta = meta

for i, arg in enumerate(args):
Expand Down Expand Up @@ -236,3 +237,99 @@ def forward(self, c: Tensor = None) -> Any:
*self._buffers.values(),
**self.kwargs,
)


class UnconditionalDistribution(Partial, LazyDistribution):
r"""Creates an unconditional lazy distribution from a constructor.
The arguments of the constructor are registered as buffers or parameters.
Arguments:
f: A distribution constructor. If `f` is a module, it is registered as a submodule.
args: The positional arguments passed to `f`.
buffer: Whether tensor arguments are registered as buffers or parameters.
kwargs: The keyword arguments passed to `f`.
Examples:
>>> f = zuko.distributions.DiagNormal
>>> mu, sigma = torch.zeros(3), torch.ones(3)
>>> base = UnconditionalDistribution(f, mu, sigma, buffer=True)
>>> base()
DiagNormal(loc: torch.Size([3]), scale: torch.Size([3]))
>>> base().sample()
tensor([ 1.5410, -0.2934, -2.1788])
"""

def __init__(
self,
f: Callable[..., Distribution],
*args: Any,
buffer: bool = False,
**kwargs,
):
super().__init__(f, *args, buffer=buffer, **kwargs)

def extra_repr(self) -> str:
if isinstance(self.f, nn.Module):
return ''
else:
return repr(self.forward())

def forward(self, c: Tensor = None) -> Distribution:
r"""
Arguments:
c: A context :math:`c`. This argument is always ignored.
Returns:
:py:`self.f(*self.args, **self.kwargs)`
"""

return super().forward()


class UnconditionalTransform(Partial, LazyTransform):
r"""Creates an unconditional lazy transformation from a constructor.
The arguments of the constructor are registered as buffers or parameters.
Arguments:
f: A transformation constructor. If `f` is a module, it is registered as a submodule.
args: The positional arguments passed to `f`.
buffer: Whether tensor arguments are registered as buffers or parameters.
kwargs: The keyword arguments passed to `f`.
Examples:
>>> f = zuko.transforms.ExpTransform
>>> t = UnconditionalTransform(f)
>>> t()
ExpTransform()
>>> x = torch.randn(3)
>>> t()(x)
tensor([4.6692, 0.7457, 0.1132])
"""

def __init__(
self,
f: Callable[..., Transform],
*args: Any,
buffer: bool = False,
**kwargs: Any,
):
super().__init__(f, *args, buffer=buffer, **kwargs)

def extra_repr(self) -> str:
if isinstance(self.f, nn.Module):
return ''
else:
return repr(self.forward())

def forward(self, c: Tensor = None) -> Transform:
r"""
Arguments:
c: A context :math:`c`. This argument is always ignored.
Returns:
:py:`self.f(*self.args, **self.kwargs)`
"""

return super().forward()
4 changes: 2 additions & 2 deletions zuko/flows/coupling.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Callable, Sequence

# isort: split
from .core import Flow, LazyTransform, Unconditional
from .core import Flow, LazyTransform, UnconditionalDistribution
from .gaussianization import ElementWiseTransform
from ..distributions import DiagNormal
from ..nn import MLP
Expand Down Expand Up @@ -177,7 +177,7 @@ def __init__(
)
)

base = Unconditional(
base = UnconditionalDistribution(
DiagNormal,
torch.zeros(features),
torch.ones(features),
Expand Down
6 changes: 3 additions & 3 deletions zuko/flows/gaussianization.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from typing import Callable, Sequence

# isort: split
from .core import Flow, LazyTransform, Unconditional
from .core import Flow, LazyTransform, UnconditionalDistribution, UnconditionalTransform
from ..distributions import DiagNormal
from ..nn import MLP
from ..transforms import (
Expand Down Expand Up @@ -140,13 +140,13 @@ def __init__(
for i in reversed(range(1, len(transforms))):
transforms.insert(
i,
Unconditional(
UnconditionalTransform(
RotationTransform,
torch.randn(features, features),
),
)

base = Unconditional(
base = UnconditionalDistribution(
DiagNormal,
torch.zeros(features),
torch.ones(features),
Expand Down
10 changes: 5 additions & 5 deletions zuko/flows/neural.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

# isort: split
from .autoregressive import MaskedAutoregressiveTransform
from .core import Flow, Unconditional
from .core import Flow, UnconditionalDistribution, UnconditionalTransform
from ..distributions import DiagNormal
from ..nn import MLP, MonotonicMLP
from ..transforms import (
Expand Down Expand Up @@ -157,9 +157,9 @@ def __init__(
]

for i in reversed(range(1, len(transforms))):
transforms.insert(i, Unconditional(SoftclipTransform, bound=11.0))
transforms.insert(i, UnconditionalTransform(SoftclipTransform, bound=11.0))

base = Unconditional(
base = UnconditionalDistribution(
DiagNormal,
torch.zeros(features),
torch.ones(features),
Expand Down Expand Up @@ -221,9 +221,9 @@ def __init__(
]

for i in reversed(range(1, len(transforms))):
transforms.insert(i, Unconditional(SoftclipTransform, bound=11.0))
transforms.insert(i, UnconditionalTransform(SoftclipTransform, bound=11.0))

base = Unconditional(
base = UnconditionalDistribution(
DiagNormal,
torch.zeros(features),
torch.ones(features),
Expand Down
4 changes: 2 additions & 2 deletions zuko/flows/polynomial.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

# isort: split
from .autoregressive import MAF
from .core import Unconditional
from .core import UnconditionalTransform
from ..transforms import BoundedBernsteinTransform, SoftclipTransform, SOSPolynomialTransform


Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
transforms = self.transform.transforms

for i in reversed(range(1, len(transforms))):
transforms.insert(i, Unconditional(SoftclipTransform, bound=11.0))
transforms.insert(i, UnconditionalTransform(SoftclipTransform, bound=11.0))


class BPF(MAF):
Expand Down
4 changes: 2 additions & 2 deletions zuko/flows/spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

# isort: split
from .autoregressive import MAF
from .core import Unconditional
from .core import UnconditionalDistribution
from ..distributions import BoxUniform
from ..transforms import CircularShiftTransform, ComposedTransform, MonotonicRQSTransform

Expand Down Expand Up @@ -104,7 +104,7 @@ def __init__(
**kwargs,
)

self.base = Unconditional(
self.base = UnconditionalDistribution(
BoxUniform,
torch.full((features,), -pi - 1e-5),
torch.full((features,), pi + 1e-5),
Expand Down
Loading

0 comments on commit 110757f

Please sign in to comment.