Commit 95ab984

remove bn sync in imagenet (jit handles it automatically)

Your Name committed Feb 5, 2025
1 parent be9a68a commit 95ab984
Showing 1 changed file with 2 additions and 7 deletions.
@@ -142,7 +142,7 @@ def is_output_params(self, param_key: spec.ParameterKey) -> bool:
           sharding_utils.get_replicated_sharding(),  # rng
       ),
       static_argnums=(0,),
-      out_shardings=sharding_utils.get_naive_sharding_spec())
+      out_shardings=sharding_utils.get_replicated_sharding())
   def _eval_model(self,
                   params: spec.ParameterContainer,
                   batch: Dict[str, spec.Tensor],
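
For context: `sharding_utils.get_replicated_sharding()` is a helper from this repository, so the sketch below reconstructs the likely equivalent directly from `jax.sharding` primitives (an assumption, not the repo's actual implementation). A `NamedSharding` with an empty `PartitionSpec` replicates the jitted output in full on every device, which is what lets the eval loop later in this diff read metrics without picking out a per-device shard.

```python
import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical stand-in for sharding_utils.get_replicated_sharding():
# an empty PartitionSpec means no axis is partitioned, i.e. fully replicated.
mesh = Mesh(np.array(jax.devices()), axis_names=('batch',))
replicated = NamedSharding(mesh, PartitionSpec())

@functools.partial(jax.jit, out_shardings=replicated)
def summed_metrics(logits):
  # Every device holds the complete output, so the caller can treat it
  # as a single global value with no cross-replica bookkeeping.
  return {'loss': jnp.sum(logits)}
```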
@@ -245,9 +245,6 @@ def _eval_model_on_split(self,
                           data_dir: str,
                           global_step: int = 0) -> Dict[str, float]:
     del global_step
-    if model_state is not None:
-      # Sync batch statistics across replicas before evaluating.
-      model_state = self.sync_batch_stats(model_state)
     num_batches = int(math.ceil(num_examples / global_batch_size))
     data_rng, eval_rng = prng.split(rng, 2)
     # We already repeat the dataset indefinitely in tf.data.
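
The deleted `sync_batch_stats` is not shown in this diff, but in pmap-based pipelines it is conventionally a cross-replica mean over the exponential-moving-average batch statistics, as in the Flax ImageNet example. A sketch of that pmap-era pattern (assumed here; `model_state.replace` presumes a Flax-style state struct holding a `batch_stats` collection):

```python
import jax
from jax import lax

# pmap-era pattern: each device accumulated its own running batch statistics,
# so they had to be averaged across replicas before evaluation.
cross_replica_mean = jax.pmap(lambda x: lax.pmean(x, 'batch'), axis_name='batch')

def sync_batch_stats(model_state):
  # Assumes a Flax-style struct with a `batch_stats` pytree.
  return model_state.replace(
      batch_stats=cross_replica_mean(model_state.batch_stats))
```

Under `jax.jit` with explicit shardings, `model_state` is a single global pytree managed by the partitioner rather than a stack of per-device copies, so there are no diverged replicas to reconcile and the sync step can simply be dropped, which is what the commit message refers to.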
@@ -270,14 +267,12 @@ def _eval_model_on_split(self,
                                         batch,
                                         model_state,
                                         step_eval_rngs)
-      # Sum up the synced metrics
-      synced_metrics = jax.tree_map(lambda x: jnp.sum(x, axis=0), synced_metrics)
       for metric_name, metric_value in synced_metrics.items():
         if metric_name not in eval_metrics:
           eval_metrics[metric_name] = 0.0
         eval_metrics[metric_name] += metric_value

-    eval_metrics = jax.tree_map(lambda x: float(x[0] / num_examples),
+    eval_metrics = jax.tree_map(lambda x: x / num_examples,
                                 eval_metrics)
     return eval_metrics
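
Why the normalization changes: under pmap (and the per-device "naive" sharding), each metric presumably carried a leading device axis, hence the per-batch `jnp.sum(x, axis=0)` and the final `x[0]` indexing in the old code. With replicated jit outputs, each accumulated metric is already a single global sum, so a plain tree-wide division suffices. A toy illustration with made-up numbers:

```python
import jax
import jax.numpy as jnp

num_examples = 50_000  # hypothetical eval-set size
# Accumulated global sums over all eval batches (made-up values):
eval_metrics = {'loss': jnp.float32(120_000.0), 'accuracy': jnp.float32(38_000.0)}

eval_metrics = jax.tree_map(lambda x: x / num_examples, eval_metrics)
# -> {'loss': 2.4, 'accuracy': 0.76}
```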
