diff --git a/mmdet/core/utils/dist_utils.py b/mmdet/core/utils/dist_utils.py index 8760774fd90..fe814903d0c 100644 --- a/mmdet/core/utils/dist_utils.py +++ b/mmdet/core/utils/dist_utils.py @@ -11,6 +11,8 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors, _unflatten_dense_tensors) +from mmdet.utils import get_device + def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1): if bucket_size_mb > 0: @@ -154,7 +156,7 @@ def all_reduce_dict(py_dict, op='sum', group=None, to_float=True): return out_dict -def sync_random_seed(seed=None, device='cuda'): +def sync_random_seed(seed=None, device=None): """Make sure different ranks share the same seed. All workers must call this function, otherwise it will deadlock. @@ -171,13 +173,15 @@ def sync_random_seed(seed=None, device='cuda'): Args: seed (int, Optional): The seed. Default to None. device (str): The device where the seed will be put on. - Default to 'cuda'. + Default to None. Returns: int: Seed to be used. """ if seed is None: seed = np.random.randint(2**31) + if device is None: + device = get_device() assert isinstance(seed, int) rank, world_size = get_dist_info()