diff --git a/byol_pytorch/trainer.py b/byol_pytorch/trainer.py
index eb1ffec0a..9ca64cc4a 100644
--- a/byol_pytorch/trainer.py
+++ b/byol_pytorch/trainer.py
@@ -14,6 +14,14 @@
 from beartype.typing import Optional
 
 from accelerate import Accelerator
+from accelerate.utils import DistributedDataParallelKwargs
+
+# constants
+
+# default DDP handler, used only when the caller supplies no `kwargs_handlers`
+DEFAULT_DDP_KWARGS = DistributedDataParallelKwargs(
+    find_unused_parameters = True
+)
 
 # functions
 
@@ -60,6 +68,12 @@ def __init__(
         accelerator_kwargs: dict = dict(),
     ):
         super().__init__()
+
+        # add the default DDP handler on a copy, so neither the caller's dict
+        # nor the shared `dict()` default argument is mutated
+        if 'kwargs_handlers' not in accelerator_kwargs:
+            accelerator_kwargs = {**accelerator_kwargs, 'kwargs_handlers': [DEFAULT_DDP_KWARGS]}
+
         self.accelerator = Accelerator(**accelerator_kwargs)
 
         if dist.is_initialized() and dist.get_world_size() > 1:
diff --git a/setup.py b/setup.py
index 96cc77df2..6350477a1 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'byol-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.8.0',
+  version = '0.8.1',
   license='MIT',
   description = 'Self-supervised contrastive learning made simple',
   author = 'Phil Wang',