Skip to content

Commit

Permalink
🎨 Stricter linting rules
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Jun 6, 2024
1 parent 763a924 commit 7c05693
Show file tree
Hide file tree
Showing 15 changed files with 56 additions and 55 deletions.
2 changes: 1 addition & 1 deletion pre-commit.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.2.2
rev: v0.4.8
hooks:
- id: ruff
- id: ruff-format
Expand Down
5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,12 @@ extend-include = ["*.ipynb"]
line-length = 99

[tool.ruff.lint]
extend-select = ["I"]
extend-select = ["I", "RUF022"]
ignore = ["E731", "E741"]
preview = true

[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401", "F403"]
"__init__.py" = ["F401"]
"test_*.py" = ["F403", "F405"]

[tool.ruff.lint.isort]
Expand Down
18 changes: 9 additions & 9 deletions zuko/distributions.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
r"""Parameterizable probability distributions."""

__all__ = [
'NormalizingFlow',
'BoxUniform',
'DiagNormal',
'GeneralizedNormal',
'Joint',
'Maximum',
'Minimum',
'Mixture',
'GeneralizedNormal',
'DiagNormal',
'BoxUniform',
'TransformedUniform',
'Truncated',
'NormalizingFlow',
'Sort',
'TopK',
'Minimum',
'Maximum',
'TransformedUniform',
'Truncated',
]

import math
Expand Down Expand Up @@ -562,7 +562,7 @@ def log_prob(self, x: Tensor) -> Tensor:
return ordered.log() + self.log_fact + self.base.log_prob(x).sum(dim=0)

def sample(self, shape: Size = ()) -> Tensor:
x = torch.movedim(self.base.sample((self.n,) + shape), 0, -1)
x = torch.movedim(self.base.sample((self.n, *shape)), 0, -1)
x = torch.sort(x, dim=-1, descending=self.descending).values

return x
Expand Down
18 changes: 9 additions & 9 deletions zuko/flows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
r"""Parameterized flows and transformations."""

from .autoregressive import *
from .continuous import *
from .core import *
from .coupling import *
from .gaussianization import *
from .mixture import *
from .neural import *
from .polynomial import *
from .spline import *
from .autoregressive import MAF, MaskedAutoregressiveTransform
from .continuous import CNF, FFJTransform
from .core import Flow, LazyDistribution, LazyInverse, LazyTransform, Unconditional
from .coupling import NICE, GeneralCouplingTransform
from .gaussianization import GF, ElementWiseTransform
from .mixture import GMM
from .neural import NAF, UNAF
from .polynomial import BPF, SOSPF
from .spline import NCSF, NSF
4 changes: 2 additions & 2 deletions zuko/flows/autoregressive.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
r"""Autoregressive flows and transformations."""

__all__ = [
'MaskedAutoregressiveTransform',
'MAF',
'MaskedAutoregressiveTransform',
]

import torch
Expand All @@ -13,7 +13,7 @@
from torch.distributions import Transform
from typing import Callable, Sequence

# isort: local
# isort: split
from .core import Flow, LazyTransform, Unconditional
from .gaussianization import ElementWiseTransform
from ..distributions import DiagNormal
Expand Down
4 changes: 2 additions & 2 deletions zuko/flows/continuous.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
r"""Continuous flows and transformations."""

__all__ = [
'FFJTransform',
'CNF',
'FFJTransform',
]

import torch
Expand All @@ -13,7 +13,7 @@
from torch import Tensor
from torch.distributions import Transform

# isort: local
# isort: split
from .core import Flow, LazyTransform, Unconditional
from ..distributions import DiagNormal
from ..nn import MLP
Expand Down
6 changes: 3 additions & 3 deletions zuko/flows/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from __future__ import annotations

__all__ = [
'Flow',
'LazyComposedTransform',
'LazyDistribution',
'LazyTransform',
'LazyComposedTransform',
'Flow',
'Unconditional',
]

Expand All @@ -17,7 +17,7 @@
from torch.distributions import Distribution, Transform
from typing import Any, Callable, Sequence, Union

# isort: local
# isort: split
from ..distributions import NormalizingFlow
from ..transforms import ComposedTransform

Expand Down
4 changes: 2 additions & 2 deletions zuko/flows/coupling.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
r"""Coupling flows and transformations."""

__all__ = [
'GeneralCouplingTransform',
'NICE',
'GeneralCouplingTransform',
]

import torch
Expand All @@ -13,7 +13,7 @@
from torch.distributions import Transform
from typing import Callable, Sequence

# isort: local
# isort: split
from .core import Flow, LazyTransform, Unconditional
from .gaussianization import ElementWiseTransform
from ..distributions import DiagNormal
Expand Down
4 changes: 2 additions & 2 deletions zuko/flows/gaussianization.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
r"""Gaussianization flows."""

__all__ = [
'ElementWiseTransform',
'GF',
'ElementWiseTransform',
]

import torch
Expand All @@ -13,7 +13,7 @@
from torch.distributions import Transform
from typing import Callable, Sequence

# isort: local
# isort: split
from .core import Flow, LazyTransform, Unconditional
from ..distributions import DiagNormal
from ..nn import MLP
Expand Down
2 changes: 1 addition & 1 deletion zuko/flows/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from torch import Tensor
from torch.distributions import Distribution, MultivariateNormal

# isort: local
# isort: split
from .core import LazyDistribution
from ..distributions import Mixture
from ..nn import MLP
Expand Down
4 changes: 2 additions & 2 deletions zuko/flows/neural.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

__all__ = [
'MNN',
'UMNN',
'NAF',
'UMNN',
'UNAF',
]

Expand All @@ -15,7 +15,7 @@
from torch.distributions import Transform
from typing import Any, Dict

# isort: local
# isort: split
from .autoregressive import MaskedAutoregressiveTransform
from .core import Flow, Unconditional
from ..distributions import DiagNormal
Expand Down
4 changes: 2 additions & 2 deletions zuko/flows/polynomial.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
r"""Polynomial flows."""

__all__ = [
'SOSPF',
'BPF',
'SOSPF',
]


# isort: local
# isort: split
from .autoregressive import MAF
from .core import Unconditional
from ..transforms import BoundedBernsteinTransform, SoftclipTransform, SOSPolynomialTransform
Expand Down
4 changes: 2 additions & 2 deletions zuko/flows/spline.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
r"""Spline flows."""

__all__ = [
'NSF',
'NCSF',
'NSF',
]

import torch

from math import pi
from torch.distributions import Transform

# isort: local
# isort: split
from .autoregressive import MAF
from .core import Unconditional
from ..distributions import BoxUniform
Expand Down
2 changes: 1 addition & 1 deletion zuko/nn.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
r"""Neural networks, layers and modules."""

__all__ = ['Linear', 'MLP', 'MaskedMLP', 'MonotonicMLP']
__all__ = ['MLP', 'Linear', 'MaskedMLP', 'MonotonicMLP']

import torch
import torch.nn as nn
Expand Down
30 changes: 15 additions & 15 deletions zuko/transforms.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
r"""Parameterizable transformations."""

__all__ = [
'AutoregressiveTransform',
'BernsteinTransform',
'BoundedBernsteinTransform',
'CircularShiftTransform',
'ComposedTransform',
'CosTransform',
'CouplingTransform',
'DependentTransform',
'FreeFormJacobianTransform',
'GaussianizationTransform',
'IdentityTransform',
'CosTransform',
'SinTransform',
'SoftclipTransform',
'CircularShiftTransform',
'SignedPowerTransform',
'LULinearTransform',
'MonotonicAffineTransform',
'MonotonicRQSTransform',
'MonotonicTransform',
'BernsteinTransform',
'BoundedBernsteinTransform',
'GaussianizationTransform',
'UnconstrainedMonotonicTransform',
'SOSPolynomialTransform',
'AutoregressiveTransform',
'CouplingTransform',
'FreeFormJacobianTransform',
'PermutationTransform',
'RotationTransform',
'LULinearTransform',
'SOSPolynomialTransform',
'SignedPowerTransform',
'SinTransform',
'SoftclipTransform',
'UnconstrainedMonotonicTransform',
]

import math
Expand All @@ -36,7 +36,7 @@
from torch.distributions.utils import _sum_rightmost
from typing import Any, Callable, Iterable, Tuple, Union

# isort: local
# isort: split
from .utils import bisection, broadcast, gauss_legendre, odeint

torch.distributions.transforms._InverseTransform.__name__ = 'Inverse'
Expand Down

0 comments on commit 7c05693

Please sign in to comment.