diff --git a/vmoe/multihost_utils.py b/vmoe/multihost_utils.py
index a04dc7d..0474911 100644
--- a/vmoe/multihost_utils.py
+++ b/vmoe/multihost_utils.py
@@ -32,8 +32,8 @@ def sync_devices(name: str, main_process: int = 0):
   # All devices will be initialized with the value 0, except the first device
   # of the `main_process`, which will be initiaized with the CRC32 of the
   # `name`.
-  h = np.int32(zlib.crc32(name.encode()))
-  x = np.zeros(jax.local_device_count(), dtype=np.int32)
+  h = np.uint32(zlib.crc32(name.encode()))
+  x = np.zeros(jax.local_device_count(), dtype=np.uint32)
   if jax.process_index() == main_process:
     x[0] = h
   # The values in all devices are summed. Thus, the result in all processes