Skip to content

Commit

Permalink
[LSC] Ignore incorrect type annotations related to jax.numpy APIs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568283238
  • Loading branch information
Jake VanderPlas authored and The precondition Authors committed Sep 25, 2023
1 parent e059707 commit 6e5a628
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion precondition/distributed_shampoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -1583,7 +1583,7 @@ def updated_statistics_from_grad(
for axis in preconditioned_dims:
update = functools.partial(gram_weighted_update, precision=precision)
if frequent_directions:
if _should_compress(self._compression_rank, g.shape[axis]):
if _should_compress(self._compression_rank, g.shape[axis]): # pytype: disable=wrong-arg-types # jnp-type
update = frequent_directions_update
new_stat = update(to_float(stats[index]), g, axis, w1, w2)
new_stats.append(from_float(new_stat))
Expand Down
2 changes: 1 addition & 1 deletion precondition/tearfree/reallocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def score_fn(
score_dict[name] = jnp.mean(
jnp.array([ops_dict[rule](ct) for ct in current_target])
)
return score_dict
return score_dict # pytype: disable=bad-return-type # jnp-type


def create_redist_dict(
Expand Down

0 comments on commit 6e5a628

Please sign in to comment.