
Commit 7f5acde

Enable compatibility with torchmetrics >= 0.6.0 (#27)

1 parent 47a71b7 commit 7f5acde

12 files changed (+59, -35 lines)

README.md (+4 -4)

@@ -86,18 +86,18 @@ class. It's
 [**init**](https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.trainer.trainer.html#pytorch_lightning.trainer.trainer.Trainer.__init__)
 method provides various configuration options.
 
-If you want to run K-Means with a GPU, you can pass the option `gpus=1` to the estimator's
-initializer:
+If you want to run K-Means with a GPU, you can pass the options `accelerator='gpu'` and `devices=1`
+to the estimator's initializer:
 
 ```python
-estimator = KMeans(3, trainer_params=dict(gpus=1))
+estimator = KMeans(3, trainer_params=dict(accelerator='gpu', devices=1))
 ```
 
 Similarly, if you want to train on 4 nodes simultaneously where each node has one GPU available,
 you can specify this as follows:
 
 ```python
-estimator = KMeans(3, trainer_params=dict(num_nodes=4, gpus=1))
+estimator = KMeans(3, trainer_params=dict(num_nodes=4, accelerator='gpu', devices=1))
 ```
 
 In fact, **you do not need to change anything else in your code**.
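For background on this change: PyTorch Lightning deprecated the `gpus` Trainer argument (around v1.7, removed in 2.0) in favor of the `accelerator`/`devices` pair. A minimal sketch of what the new `trainer_params` amount to, assuming they are forwarded verbatim to the Trainer as the README describes:

```python
import pytorch_lightning as pl

# trainer_params as passed to the estimator above ...
trainer_params = dict(accelerator="gpu", devices=1)

# ... presumably end up here; this is the Lightning >= 1.7 spelling of gpus=1.
trainer = pl.Trainer(max_epochs=100, **trainer_params)
```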

docs/index.rst (+3 -3)

@@ -78,19 +78,19 @@ For GPU- and multi-node training, PyCave leverages PyTorch Lightning. The hardwa
 runs on is determined by the :class:`pytorch_lightning.trainer.Trainer` class. It's
 :meth:`~pytorch_lightning.trainer.Trainer.__init__` method provides various configuration options.
 
-If you want to run K-Means with a GPU, you can pass the option ``gpus=1`` to the estimator's
+If you want to run K-Means with a GPU, you can pass the options ``accelerator='gpu'`` and ``devices=1`` to the estimator's
 initializer:
 
 .. code-block:: python
 
-    estimator = KMeans(3, trainer_params=dict(gpus=1))
+    estimator = KMeans(3, trainer_params=dict(accelerator='gpu', devices=1))
 
 Similarly, if you want to train on 4 nodes simultaneously where each node has one GPU available,
 you can specify this as follows:
 
 .. code-block:: python
 
-    estimator = KMeans(3, trainer_params=dict(num_nodes=4, gpus=1))
+    estimator = KMeans(3, trainer_params=dict(num_nodes=4, accelerator='gpu', devices=1))
 
 In fact, **you do not need to change anything else in your code**.

poetry.lock (+23 -17)

Some generated files are not rendered by default.

pycave/bayes/gmm/lightning_module.py (+2 -2)

@@ -2,7 +2,7 @@
 import pytorch_lightning as pl
 import torch
 from pytorch_lightning.callbacks import EarlyStopping
-from torchmetrics import AverageMeter
+from torchmetrics import MeanMetric
 from pycave.bayes.core import cholesky_precision
 from pycave.utils import NonparametricLightningModule
 from .metrics import CovarianceAggregator, MeanAggregator, PriorAggregator
@@ -65,7 +65,7 @@ def __init__(
         )
 
         # Initialize metrics
-        self.metric_nll = AverageMeter(dist_sync_fn=self.all_gather)
+        self.metric_nll = MeanMetric(dist_sync_fn=self.all_gather)
 
     def configure_callbacks(self) -> list[pl.Callback]:
         if self.convergence_tolerance == 0:
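`MeanMetric` is torchmetrics' successor to `AverageMeter` and a drop-in replacement here: it keeps a running (optionally weighted) mean across `update()` calls. A small standalone sketch with made-up values:

```python
import torch
from torchmetrics import MeanMetric

metric_nll = MeanMetric()
metric_nll.update(torch.tensor([2.3, 1.9, 2.1]))  # per-batch NLL values
metric_nll.update(torch.tensor([1.7]))
print(metric_nll.compute())  # tensor(2.), the mean over all updated values
```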

pycave/bayes/gmm/metrics.py (+6)

@@ -9,6 +9,8 @@ class PriorAggregator(Metric):
     The prior aggregator aggregates component probabilities over batches and process.
     """
 
+    full_state_update = False
+
     def __init__(
         self,
         num_components: int,
@@ -33,6 +35,8 @@ class MeanAggregator(Metric):
     The mean aggregator aggregates component means over batches and processes.
     """
 
+    full_state_update = False
+
     def __init__(
         self,
         num_components: int,
@@ -63,6 +67,8 @@ class CovarianceAggregator(Metric):
     The covariance aggregator aggregates component covariances over batches and processes.
     """
 
+    full_state_update = False
+
     def __init__(
         self,
         num_components: int,
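A note on the `full_state_update` flags added throughout this commit: newer torchmetrics releases (0.8+, if memory serves) let a `Metric` subclass declare that its `update()` for one batch does not depend on previously accumulated state, which allows `forward()` to take a cheaper single-pass path instead of recomputing with full state, and avoids a warning when the attribute is left unset. On older versions the extra class attribute is simply ignored. A hedged, self-contained sketch of the pattern; `RunningSum` is an invented example, not a PyCave class:

```python
import torch
from torchmetrics import Metric


class RunningSum(Metric):
    # False: the state change contributed by one batch is independent of
    # earlier batches, so torchmetrics may use the faster forward() path.
    full_state_update = False

    def __init__(self) -> None:
        super().__init__()
        self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")

    def update(self, values: torch.Tensor) -> None:
        self.total = self.total + values.sum()

    def compute(self) -> torch.Tensor:
        return self.total
```

Metrics whose per-batch result does depend on what was seen earlier should keep `full_state_update = True`; that is presumably why `BatchSummer` in `pycave/clustering/kmeans/metrics.py` below is the one aggregator that opts into `True`.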

pycave/bayes/markov_chain/lightning_module.py (+2 -2)

@@ -1,6 +1,6 @@
 import torch
 from torch.nn.utils.rnn import PackedSequence
-from torchmetrics import AverageMeter
+from torchmetrics import MeanMetric
 from pycave.bayes.markov_chain.metrics import StateCountAggregator
 from pycave.utils import NonparametricLightningModule
 from .model import MarkovChainModel
@@ -27,7 +27,7 @@ def __init__(self, model: MarkovChainModel, symmetric: bool = False):
             symmetric=self.symmetric,
             dist_sync_fn=self.all_gather,
         )
-        self.metric_nll = AverageMeter(dist_sync_fn=self.all_gather)
+        self.metric_nll = MeanMetric(dist_sync_fn=self.all_gather)
 
     def on_train_epoch_start(self) -> None:
         self.aggregator.reset()

pycave/bayes/markov_chain/metrics.py (+2)

@@ -9,6 +9,8 @@ class StateCountAggregator(Metric):
     The state count aggregator aggregates initial states and transitions between states.
     """
 
+    full_state_update = False
+
     def __init__(
         self,
         num_states: int,

pycave/clustering/kmeans/lightning_module.py (+4 -4)

@@ -4,7 +4,7 @@
 import pytorch_lightning as pl
 import torch
 from pytorch_lightning.callbacks import EarlyStopping
-from torchmetrics import AverageMeter
+from torchmetrics import MeanMetric
 from pycave.utils import NonparametricLightningModule
 from .metrics import (
     BatchAverager,
@@ -51,7 +51,7 @@ def __init__(
         )
 
         # Initialize metrics
-        self.metric_inertia = AverageMeter()
+        self.metric_inertia = MeanMetric()
 
     def configure_callbacks(self) -> List[pl.Callback]:
         if self.convergence_tolerance == 0:
@@ -239,8 +239,8 @@ def nonparametric_training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
 
     def nonparametric_training_epoch_end(self) -> None:
         if self.current_epoch == 0:
-            choice = self.uniform_sampler.compute()[0]
-            self.model.centroids[0].copy_(choice)
+            choice = self.uniform_sampler.compute()
+            self.model.centroids[0].copy_(choice[0] if choice.dim() > 0 else choice)
         elif self._is_current_epoch_sampling:
             candidates = self.distance_sampler.compute()
             self.centroid_candidates.copy_(candidates)
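The reworked epoch-end logic guards against a difference in what `compute()` returns across torchmetrics versions: a sampler configured for a single choice may yield a 1-element tensor or a 0-dim scalar. A standalone illustration of the guard, with made-up tensors:

```python
import torch

def first_choice(choice: torch.Tensor) -> torch.Tensor:
    # Take the first element if compute() returned a batch of choices,
    # or the value itself if it is already a 0-dim scalar.
    return choice[0] if choice.dim() > 0 else choice

print(first_choice(torch.tensor([4.0, 2.0])))  # tensor(4.)
print(first_choice(torch.tensor(7.0)))         # tensor(7.)
```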

pycave/clustering/kmeans/metrics.py (+10)

@@ -9,6 +9,8 @@ class CentroidAggregator(Metric):
     The centroid aggregator aggregates kmeans centroids over batches and processes.
     """
 
+    full_state_update = False
+
     def __init__(
         self,
         num_clusters: int,
@@ -49,6 +51,8 @@ class UniformSampler(Metric):
     they were already sampled from).
     """
 
+    full_state_update = False
+
     def __init__(
         self,
         num_choices: int,
@@ -109,6 +113,8 @@ class DistanceSampler(Metric):
     duplicates.
     """
 
+    full_state_update = False
+
     def __init__(
         self,
         num_choices: int,
@@ -169,6 +175,8 @@ class BatchSummer(Metric):
     Sums the values for a batch of items independently.
     """
 
+    full_state_update = True
+
     def __init__(self, num_values: int, *, dist_sync_fn: Optional[Callable[[Any], Any]] = None):
         super().__init__(dist_sync_fn=dist_sync_fn)  # type: ignore
 
@@ -187,6 +195,8 @@ class BatchAverager(Metric):
     Averages the values for a batch of items independently.
     """
 
+    full_state_update = False
+
     def __init__(
         self,
         num_values: int,

pyproject.toml (+1 -1)

@@ -18,7 +18,7 @@ numpy = "^1.20.3"
 python = ">=3.8,<3.11"
 pytorch-lightning = "^1.6.0"
 torch = "^1.8.0"
-torchmetrics = "^0.5.1,<0.6.0"
+torchmetrics = "^0.6.0"
 
 [tool.poetry.group.pre-commit.dependencies]
 black = "^22.12.0"
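Note that Poetry's caret operator pins the minor version for 0.x releases, so `^0.6.0` is equivalent to `>=0.6.0,<0.7.0` rather than a fully open `>= 0.6.0`.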

tests/bayes/gmm/benchmark_gmm_estimator.py (+1 -1)

@@ -128,6 +128,6 @@ def test_pycave_gpu(
         convergence_tolerance=0,
         covariance_regularization=1e-3,
         batch_size=batch_size,
-        trainer_params=dict(max_epochs=100, gpus=1),
+        trainer_params=dict(max_epochs=100, accelerator="gpu", devices=1),
     )
     benchmark(estimator.fit, data)

tests/clustering/kmeans/benchmark_kmeans_estimator.py (+1 -1)

@@ -120,6 +120,6 @@ def test_pycave_gpu(
         init_strategy=init_strategy,
         batch_size=batch_size,
         convergence_tolerance=0,
-        trainer_params=dict(gpus=1, max_epochs=100),
+        trainer_params=dict(max_epochs=100, accelerator="gpu", devices=1),
     )
     benchmark(estimator.fit, data)
