Skip to content

Commit

Permalink
jax.numpy.clip: update use of deprecated arguments.
Browse files Browse the repository at this point in the history
- a is now positional-only
- a_min is now min
- a_max is now max

The old argument names have been deprecated since JAX v0.4.27.

PiperOrigin-RevId: 715343439
Change-Id: I50b086b249360c142f42ed4d8e50d48692c11dfa
  • Loading branch information
Jake VanderPlas authored and copybara-github committed Jan 14, 2025
1 parent 6350ddd commit fa0c526
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions alphafold/model/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -1854,8 +1854,8 @@ def __call__(self, batch, is_training, safe_key=None):
rel_pos = jax.nn.one_hot(
jnp.clip(
offset + c.max_relative_feature,
a_min=0,
a_max=2 * c.max_relative_feature),
min=0,
max=2 * c.max_relative_feature),
2 * c.max_relative_feature + 1)
pair_activations += common_modules.Linear(
c.pair_channel, name='pair_activiations')(
Expand Down
4 changes: 2 additions & 2 deletions alphafold/model/modules_multimer.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,7 +564,7 @@ def _relative_encoding(self, batch):
dtype = jnp.bfloat16 if gc.bfloat16 else jnp.float32

clipped_offset = jnp.clip(
offset + c.max_relative_idx, a_min=0, a_max=2 * c.max_relative_idx)
offset + c.max_relative_idx, min=0, max=2 * c.max_relative_idx)

if c.use_chain_relative:

Expand All @@ -586,7 +586,7 @@ def _relative_encoding(self, batch):
max_rel_chain = c.max_relative_chain

clipped_rel_chain = jnp.clip(
rel_sym_id + max_rel_chain, a_min=0, a_max=2 * max_rel_chain)
rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain)

final_rel_chain = jnp.where(entity_id_same, clipped_rel_chain,
(2 * max_rel_chain + 1) *
Expand Down

0 comments on commit fa0c526

Please sign in to comment.