From a13fe0200b3f2dcf5d897d5b53000fa021ca0222 Mon Sep 17 00:00:00 2001 From: Joan Puigcerver Date: Mon, 16 Sep 2024 05:29:38 -0700 Subject: [PATCH] Update to pass tests with NumPy v2.0.1. PiperOrigin-RevId: 675108871 --- vmoe/multihost_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vmoe/multihost_utils.py b/vmoe/multihost_utils.py index a04dc7d..0474911 100644 --- a/vmoe/multihost_utils.py +++ b/vmoe/multihost_utils.py @@ -32,8 +32,8 @@ def sync_devices(name: str, main_process: int = 0): # All devices will be initialized with the value 0, except the first device # of the `main_process`, which will be initiaized with the CRC32 of the # `name`. - h = np.int32(zlib.crc32(name.encode())) - x = np.zeros(jax.local_device_count(), dtype=np.int32) + h = np.uint32(zlib.crc32(name.encode())) + x = np.zeros(jax.local_device_count(), dtype=np.uint32) if jax.process_index() == main_process: x[0] = h # The values in all devices are summed. Thus, the result in all processes