Skip to content

Commit

Permalink
Merge pull request #22 from JamesPerlman/patch-1
Browse files Browse the repository at this point in the history
Save result from jax.local_device_count()
  • Loading branch information
keunhong authored Feb 17, 2022
2 parents a371bc2 + 1f87f3f commit d433ebe
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion hypernerf/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def points_bounding_size(points):
def shard(xs, device_count=None):
"""Split data into shards for multiple devices along the first dimension."""
if device_count is None:
jax.local_device_count()
device_count = jax.local_device_count()
return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)


Expand Down

0 comments on commit d433ebe

Please sign in to comment.