From 1f87f3fd1f9fdada1933fd644021d8adc6384945 Mon Sep 17 00:00:00 2001 From: JamesPerlman Date: Thu, 17 Feb 2022 00:43:37 -0800 Subject: [PATCH] Save result from jax.local_device_count() Identical issue to https://github.com/google/nerfies/pull/47 --- hypernerf/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hypernerf/utils.py b/hypernerf/utils.py index 5060a60..6e7e97f 100644 --- a/hypernerf/utils.py +++ b/hypernerf/utils.py @@ -285,7 +285,7 @@ def points_bounding_size(points): def shard(xs, device_count=None): """Split data into shards for multiple devices along the first dimension.""" if device_count is None: - jax.local_device_count() + device_count = jax.local_device_count() return jax.tree_map(lambda x: x.reshape((device_count, -1) + x.shape[1:]), xs)