Skip to content

Commit

Permalink
Update to pass tests with NumPy v2.0.1.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675108871
  • Loading branch information
jpuigcerver authored and copybara-github committed Sep 16, 2024
1 parent 389af94 commit a13fe02
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions vmoe/multihost_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def sync_devices(name: str, main_process: int = 0):
# All devices will be initialized with the value 0, except the first device
# of the `main_process`, which will be initiaized with the CRC32 of the
# `name`.
h = np.int32(zlib.crc32(name.encode()))
x = np.zeros(jax.local_device_count(), dtype=np.int32)
h = np.uint32(zlib.crc32(name.encode()))
x = np.zeros(jax.local_device_count(), dtype=np.uint32)
if jax.process_index() == main_process:
x[0] = h
# The values in all devices are summed. Thus, the result in all processes
Expand Down

0 comments on commit a13fe02

Please sign in to comment.