diff --git a/vmoe/multihost_utils.py b/vmoe/multihost_utils.py index a04dc7d..0474911 100644 --- a/vmoe/multihost_utils.py +++ b/vmoe/multihost_utils.py @@ -32,8 +32,8 @@ def sync_devices(name: str, main_process: int = 0): # All devices will be initialized with the value 0, except the first device # of the `main_process`, which will be initiaized with the CRC32 of the # `name`. - h = np.int32(zlib.crc32(name.encode())) - x = np.zeros(jax.local_device_count(), dtype=np.int32) + h = np.uint32(zlib.crc32(name.encode())) + x = np.zeros(jax.local_device_count(), dtype=np.uint32) if jax.process_index() == main_process: x[0] = h # The values in all devices are summed. Thus, the result in all processes