diff --git a/precondition/distributed_shampoo.py b/precondition/distributed_shampoo.py index 2542477..ac8465e 100644 --- a/precondition/distributed_shampoo.py +++ b/precondition/distributed_shampoo.py @@ -1583,7 +1583,7 @@ def updated_statistics_from_grad( for axis in preconditioned_dims: update = functools.partial(gram_weighted_update, precision=precision) if frequent_directions: - if _should_compress(self._compression_rank, g.shape[axis]): + if _should_compress(self._compression_rank, g.shape[axis]): # pytype: disable=wrong-arg-types # jnp-type update = frequent_directions_update new_stat = update(to_float(stats[index]), g, axis, w1, w2) new_stats.append(from_float(new_stat)) diff --git a/precondition/tearfree/reallocation.py b/precondition/tearfree/reallocation.py index a9dac94..591d083 100644 --- a/precondition/tearfree/reallocation.py +++ b/precondition/tearfree/reallocation.py @@ -154,7 +154,7 @@ def score_fn( score_dict[name] = jnp.mean( jnp.array([ops_dict[rule](ct) for ct in current_target]) ) - return score_dict + return score_dict # pytype: disable=bad-return-type # jnp-type def create_redist_dict(