
Releases: jax-ml/jax

JAX release v0.3.8

30 Apr 03:09
  • Changes
    • jax.numpy.linalg.svd on TPUs uses a qdwh-svd solver.
    • jax.numpy.linalg.cond on TPUs now accepts complex input.
    • jax.numpy.linalg.pinv on TPUs now accepts complex input.
    • jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
    • jax.scipy.cluster.vq.vq has been added.
    • jax.experimental.maps.mesh has been deleted.
      Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh
      for more information.
    • jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr (#10452).
    • jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
    • jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
    • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that, when differentiated, the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously, out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
    • jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of numpy.take_along_axis. Previously, non-integer indices were silently cast to integers.
    • jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of numpy.ravel_multi_index. Previously, non-integer dims were silently cast to integers.
    • jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of numpy.split. Previously, a non-integer axis was silently cast to an integer.
    • jax.numpy.indices now raises a TypeError if its dimensions are not of an integer type, matching the behavior of numpy.indices. Previously, non-integer dimensions were silently cast to integers.
    • jax.numpy.diag now raises a TypeError if its k argument is not of an integer type, matching the behavior of numpy.diag. Previously, a non-integer k was silently cast to an integer.
    • Added {func}jax.random.orthogonal.
  • Deprecations
    • Many functions and objects available in jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance (#10389). These, along with the previously deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in, e.g., unittest, absl.testing, and numpy.testing. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as jax.devices. Many of the deprecated utilities will still exist in jax._src.test_util, but these are not public APIs and may be changed or removed without notice in future releases.
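A minimal sketch of several of the v0.3.8 changes above, written against a recent JAX install (assumption: later releases kept these semantics). The arrays, the index values, and the `raises_type_error` helper are illustrative only and are not part of JAX; `numpy.testing.assert_allclose` and `jax.devices()` stand in for the deprecated `jax.test_util` helpers.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.scipy.cluster.vq import vq
from jax.scipy.linalg import qr

x = jnp.array([1.0, 2.0, 3.0])
idx = jnp.array([0, 5])  # index 5 is out of bounds for a length-3 array

# take / take_along_axis: out-of-bounds entries are filled with NaN by default;
# mode="clip" restores the old clamping behavior.
taken_fill = jnp.take(x, idx)
taken_clip = jnp.take(x, idx, mode="clip")
along_fill = jnp.take_along_axis(x, idx, axis=0)
along_clip = jnp.take_along_axis(x, idx, axis=0, mode="clip")


def scatter_total(updates):
    # Scatter-add into a zero vector; the update aimed at index 5 is dropped.
    return jnp.zeros(3).at[idx].add(updates).sum()


# With "drop" semantics, the out-of-bounds update receives a zero cotangent.
grad_updates = jax.grad(scatter_total)(jnp.array([1.0, 2.0]))

# qr(..., mode="r") returns a length-1 tuple, as scipy.linalg.qr does.
a = jnp.array([[2.0, 1.0], [1.0, 3.0]])
(r,) = qr(a, mode="r")

# vq assigns each observation the index of its nearest code-book entry.
obs = jnp.array([[0.0, 0.0], [1.0, 1.0], [9.0, 9.0]])
code_book = jnp.array([[0.0, 0.0], [10.0, 10.0]])
codes, dists = vq(obs, code_book)


def raises_type_error(fn):
    # Hypothetical helper: True if calling fn() raises TypeError.
    try:
        fn()
    except TypeError:
        return True
    return False


# Non-integer index-like arguments are rejected instead of silently cast.
float_indices_rejected = raises_type_error(
    lambda: jnp.take_along_axis(x, jnp.array([0.0, 1.0]), axis=0))
float_k_rejected = raises_type_error(lambda: jnp.diag(x, k=1.0))

# random.orthogonal samples an n x n orthogonal matrix; numpy.testing can
# stand in for the deprecated jax.test_util.check_close.
q = jax.random.orthogonal(jax.random.PRNGKey(0), 3)
np.testing.assert_allclose(q @ q.T, np.eye(3), atol=1e-5)

# Public replacement for device checks such as device_under_test().
platform = jax.devices()[0].platform
```

The gradient example uses `.at[...].add` rather than `.set`; both are scatter operations, and the additive form keeps the cotangent easy to read off.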

Jaxlib v0.3.7

29 Apr 18:18
  • Linux wheels are now built conforming to the manylinux2014 standard, instead of manylinux2010.

JAX release v0.3.7

29 Apr 18:09
  • Fixed a performance problem when the indices passed to jax.numpy.take_along_axis were broadcast (#10281).
  • jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
  • The DeviceArray.tile() method is deprecated because NumPy arrays do not have a tile() method. Use jax.numpy.tile instead (#10266).
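A short sketch of the v0.3.7 changes, assuming a recent JAX install: integer inputs to `expit` come back as floating point, and `jnp.tile` covers what the deprecated method did. The input values are illustrative.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

# Integer arguments are promoted to floating point before evaluation.
probs = expit(jnp.array([0, 2]))

# logit is the inverse of expit, so logit(0.5) is 0.
logits = logit(jnp.array([0.5]))

# jnp.tile replaces the deprecated DeviceArray.tile() method.
tiled = jnp.tile(jnp.array([1, 2]), 2)
```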

JAX release v0.3.6

13 Apr 00:52
  • Changes:
    • Upgraded the libtpu wheel to a fixed version. Fixes #10218.

JAX release v0.3.5

07 Apr 20:29

Changes

  • Added jax.random.loggamma and improved the behavior of jax.random.beta
    and jax.random.dirichlet for small parameter values (#9906).
  • The private lax_numpy submodule is no longer exposed in the jax.numpy namespace (#10029).
  • Added array creation routines jax.numpy.frombuffer, jax.numpy.fromfunction,
    and jax.numpy.fromstring (#10049).
  • DeviceArray.copy() now returns a DeviceArray rather than an np.ndarray (#10069).
  • Added jax.scipy.linalg.rsf2csf.
  • Deprecations:
    • jax.nn.normalize is being deprecated. Use jax.nn.standardize instead (#9899).
    • jax.tree_util.tree_multimap is deprecated. Use jax.tree_util.tree_map instead (#5746).
    • jax.experimental.sharded_jit is deprecated. Use pjit instead.
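A minimal sketch of some v0.3.5 additions and their suggested replacements, assuming a recent JAX install. The parameter dictionaries and the 0.1 step size are hypothetical values chosen only to show `tree_map` taking several trees at once, which is what `tree_multimap` used to do.

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.array([1.0, 2.0]), "b": jnp.array(0.5)}
grads = {"w": jnp.array([0.1, 0.2]), "b": jnp.array(0.05)}

# tree_map accepts multiple trees with the same structure,
# so it covers what tree_multimap did.
updated = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)

# jax.nn.standardize replaces jax.nn.normalize: zero mean, unit variance.
z = jax.nn.standardize(jnp.array([1.0, 2.0, 3.0, 4.0]))

# New array-creation routine: entry (i, j) holds i + j.
grid = jnp.fromfunction(lambda i, j: i + j, (2, 3), dtype=jnp.int32)

# New sampler: log of Gamma(a) samples, intended to stay usable for small a.
log_samples = jax.random.loggamma(jax.random.PRNGKey(0), 0.01, shape=(4,))
```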

JAX release v0.3.4

18 Mar 21:13

Fix a bug introduced in #9923.

JAX release v0.3.3

17 Mar 22:31

Jax release v0.3.1

18 Feb 22:36
  • Changes:
    • jax.test_util.JaxTestCase and jax.test_util.JaxTestLoader are now deprecated.
      The suggested replacement is to use absl.testing's parameterized.TestCase directly. For tests that
      rely on custom asserts such as JaxTestCase.assertAllClose(), the suggested replacement
      is to use standard NumPy testing utilities such as numpy.testing.assert_allclose(),
      which work directly with JAX arrays (#9620).
    • jax.test_util.JaxTestCase now sets jax_numpy_rank_promotion='raise' by default
      (#9562). To recover the previous behavior, use the new
      jax.test_util.with_config decorator:
      from jax import test_util as jtu

      @jtu.with_config(jax_numpy_rank_promotion='allow')
      class MyTestCase(jtu.JaxTestCase):
        ...
    • Added jax.scipy.linalg.schur, jax.scipy.linalg.sqrtm,
      jax.scipy.signal.csd, jax.scipy.signal.stft,
      and jax.scipy.signal.welch.
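A sketch of a test written without jax.test_util, under two assumptions: it uses plain unittest.TestCase (swapped in for absl's parameterized.TestCase to avoid the extra dependency), and it assumes a JAX version that exposes the `jax.numpy_rank_promotion` context manager as a scoped alternative to the with_config decorator. `ExampleTest` is a hypothetical class name.

```python
import unittest

import jax
import jax.numpy as jnp
import numpy as np


class ExampleTest(unittest.TestCase):
    # Plain unittest.TestCase instead of the deprecated jtu.JaxTestCase.

    def test_allclose(self):
        x = jnp.linspace(0.0, 1.0, 5)
        # numpy.testing works directly with JAX arrays.
        np.testing.assert_allclose(
            jnp.sin(x) ** 2 + jnp.cos(x) ** 2, np.ones(5), rtol=1e-5)

    def test_rank_promotion_raises(self):
        # Scoped equivalent of jax_numpy_rank_promotion='raise'.
        with jax.numpy_rank_promotion("raise"):
            with self.assertRaises(ValueError):
                jnp.ones(3) + jnp.ones((2, 3))


# Run the suite programmatically rather than via unittest.main().
suite = unittest.defaultTestLoader.loadTestsFromTestCase(ExampleTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```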

Jaxlib release v0.3.0

10 Feb 20:07
  • Changes
    • Bazel 5.0.0 is now required to build jaxlib.
    • jaxlib version has been bumped to 0.3.0. Please see the design doc
      for the explanation.

Jax release v0.3.0

10 Feb 20:07
  • Changes
    • jax version has been bumped to 0.3.0. Please see the design doc
      for the explanation.