[Fix] sync_random_seed func adapts to different devices (#10490)
Ginray committed Jun 13, 2023
1 parent 79fc6ab commit 8d28a08
Showing 1 changed file with 6 additions and 2 deletions.
mmdet/core/utils/dist_utils.py (6 additions, 2 deletions)
@@ -11,6 +11,8 @@
 from torch._utils import (_flatten_dense_tensors, _take_tensors,
                           _unflatten_dense_tensors)
 
+from mmdet.utils import get_device
+
 
 def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
     if bucket_size_mb > 0:
@@ -154,7 +156,7 @@ def all_reduce_dict(py_dict, op='sum', group=None, to_float=True):
     return out_dict
 
 
-def sync_random_seed(seed=None, device='cuda'):
+def sync_random_seed(seed=None, device=None):
     """Make sure different ranks share the same seed.
 
     All workers must call this function, otherwise it will deadlock.
@@ -171,13 +173,15 @@ def sync_random_seed(seed=None, device='cuda'):
     Args:
         seed (int, Optional): The seed. Default to None.
         device (str): The device where the seed will be put on.
-            Default to 'cuda'.
+            Default to None.
 
     Returns:
         int: Seed to be used.
     """
     if seed is None:
         seed = np.random.randint(2**31)
+    if device is None:
+        device = get_device()
    assert isinstance(seed, int)
 
     rank, world_size = get_dist_info()
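For reference, below is a minimal sketch of how the whole function reads after this commit. Only the `device` handling mirrors the hunks above; the broadcast body, the `get_dist_info` import, and the single-process early return are reconstructed assumptions about the surrounding file, not part of the shown diff.

```python
# Sketch of the patched sync_random_seed. Everything outside the `device`
# handling is an assumption based on the surrounding context of the diff.
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import get_dist_info  # assumed import, as in mmcv 1.x

from mmdet.utils import get_device


def sync_random_seed(seed=None, device=None):
    """Share one random seed across all ranks.

    All workers must call this function, otherwise it will deadlock.
    """
    if seed is None:
        seed = np.random.randint(2**31)
    if device is None:
        # Pick the runtime's actual device instead of hard-coding 'cuda',
        # which is the point of this commit.
        device = get_device()
    assert isinstance(seed, int)

    rank, world_size = get_dist_info()
    if world_size == 1:
        return seed

    # Rank 0 holds the seed; all other ranks receive it via broadcast,
    # so every process returns the same integer.
    if rank == 0:
        random_num = torch.tensor(seed, dtype=torch.int32, device=device)
    else:
        random_num = torch.tensor(0, dtype=torch.int32, device=device)
    dist.broadcast(random_num, src=0)
    return random_num.item()
```

The practical effect: with `device=None`, the seed tensor is allocated on whatever `get_device()` reports for the current runtime (for example 'cuda', 'npu', 'mlu', or 'cpu'), so the broadcast no longer assumes a CUDA backend.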
