Releases: jax-ml/jax
JAX release v0.3.8
- GitHub commits.
- Changes
  - `jax.numpy.linalg.svd` on TPUs uses a qdwh-svd solver.
  - `jax.numpy.linalg.cond` on TPUs now accepts complex input.
  - `jax.numpy.linalg.pinv` on TPUs now accepts complex input.
  - `jax.numpy.linalg.matrix_rank` on TPUs now accepts complex input.
  - `jax.scipy.cluster.vq.vq` has been added.
  - `jax.experimental.maps.mesh` has been deleted. Please use `jax.experimental.maps.Mesh`. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
  - `jax.scipy.linalg.qr` now returns a length-1 tuple rather than the raw array when `mode='r'`, in order to match the behavior of `scipy.linalg.qr` (#10452).
  - `jax.numpy.take_along_axis` now takes an optional `mode` parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing `mode="clip"`.
  - `jax.numpy.take` now defaults to `mode="fill"`, which returns invalid values (e.g., NaN) for out-of-bounds indices.
  - Scatter operations, such as `x.at[...].set(...)`, now have `"drop"` semantics. This has no effect on the scatter operation itself, but it means that, when differentiated, the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously, out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
  - `jax.numpy.take_along_axis` now raises a `TypeError` if its indices are not of an integer type, matching the behavior of `numpy.take_along_axis`. Previously, non-integer indices were silently cast to integers.
  - `jax.numpy.ravel_multi_index` now raises a `TypeError` if its `dims` argument is not of an integer type, matching the behavior of `numpy.ravel_multi_index`. Previously, non-integer `dims` was silently cast to integers.
  - `jax.numpy.split` now raises a `TypeError` if its `axis` argument is not of an integer type, matching the behavior of `numpy.split`. Previously, non-integer `axis` was silently cast to integers.
  - `jax.numpy.indices` now raises a `TypeError` if its dimensions are not of an integer type, matching the behavior of `numpy.indices`. Previously, non-integer dimensions were silently cast to integers.
  - `jax.numpy.diag` now raises a `TypeError` if its `k` argument is not of an integer type, matching the behavior of `numpy.diag`. Previously, non-integer `k` was silently cast to integers.
  - Added `jax.random.orthogonal`.
- Deprecations
  - Many functions and objects available in `jax.test_util` are now deprecated and will raise a warning on import. This includes `cases_from_list`, `check_close`, `check_eq`, `device_under_test`, `format_shape_dtype_string`, `rand_uniform`, `skip_on_devices`, `with_config`, `xla_bridge`, and `_default_tolerance` (#10389). These, along with the previously deprecated `JaxTestCase`, `JaxTestLoader`, and `BufferDonationTestCase`, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard Python and NumPy testing utilities found in, e.g., `unittest`, `absl.testing`, and `numpy.testing`. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as `jax.devices`. Many of the deprecated utilities will still exist in `jax._src.test_util`, but these are not public APIs and as such may be changed or removed without notice in future releases.
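The out-of-bounds indexing changes above can be checked with a short sketch. It assumes a recent JAX release in which `mode="fill"` is still the default for `jnp.take` and `jnp.take_along_axis`, and in which `.at[...].set(...)` still drops out-of-bounds updates; the array values and the helper `scatter_then_sum` are hypothetical, chosen only for illustration.

```python
import jax
import jax.numpy as jnp

x = jnp.array([10.0, 20.0, 30.0])
idx = jnp.array([0, 2, 5])  # 5 is out of bounds for a length-3 array

# jnp.take defaults to mode="fill": out-of-bounds positions become NaN
# for floating-point inputs.
filled = jnp.take(x, idx)
# mode="clip" clamps index 5 to the last valid index (2), the pre-0.3.8 behavior.
clipped = jnp.take(x, idx, mode="clip")

# jnp.take_along_axis follows the same rule, with mode="clip" as the opt-out.
along_filled = jnp.take_along_axis(x, idx, axis=0)
along_clipped = jnp.take_along_axis(x, idx, axis=0, mode="clip")


def scatter_then_sum(v):
    # Hypothetical helper: write v[0] to position 1 and v[1] to position 3.
    # Position 3 is out of bounds, so under "drop" semantics that update is
    # discarded.
    return jnp.zeros(3).at[jnp.array([1, 3])].set(v).sum()


# The dropped update contributes nothing, so its gradient entry is zero.
grad = jax.grad(scatter_then_sum)(jnp.array([1.0, 2.0]))
print(filled, clipped, grad)
```

Code that relied on the old clamping behavior can opt back in explicitly with `mode="clip"`, which keeps the intent visible at the call site.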
Jaxlib v0.3.7
- Linux wheels are now built conforming to the `manylinux2014` standard, instead of `manylinux2010`.
JAX release v0.3.7
- Fixed a performance problem if the indices passed to `jax.numpy.take_along_axis` were broadcasted (#10281).
- `jax.scipy.special.expit` and `jax.scipy.special.logit` now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
- The `DeviceArray.tile()` method is deprecated, because NumPy arrays do not have a `tile()` method. As a replacement for this, use `jax.numpy.tile` (#10266).
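A minimal sketch of the v0.3.7 migration, assuming a recent JAX: call the function `jax.numpy.tile` in place of the deprecated `DeviceArray.tile()` method, and pass JAX arrays (rather than Python lists) to `expit` and `logit`. The input values are illustrative only.

```python
import jax.numpy as jnp
from jax.scipy.special import expit, logit

x = jnp.array([1, 2, 3])

# Function form replacing the deprecated x.tile(2) method.
tiled = jnp.tile(x, 2)

# expit and logit take scalars or JAX arrays; integer inputs are promoted to
# a floating-point dtype before the computation.
probs = expit(jnp.array([0, 1]))
log_odds = logit(jnp.array([0.25, 0.5]))
print(tiled, probs, log_odds)
```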
JAX release v0.3.6
- Changes:
- Upgraded libtpu wheel to the fixed version. Fixes #10218.
JAX release v0.3.5
Changes
- added `jax.random.loggamma` & improved behavior of `jax.random.beta` and `jax.random.dirichlet` for small parameter values (#9906).
- the private `lax_numpy` submodule is no longer exposed in the `jax.numpy` namespace (#10029).
- added array creation routines `jax.numpy.frombuffer`, `jax.numpy.fromfunction`, and `jax.numpy.fromstring` (#10049).
- `DeviceArray.copy()` now returns a `DeviceArray` rather than a `np.ndarray` (#10069).
- added `jax.scipy.linalg.rsf2csf`.
- Deprecations:
JAX release v0.3.4
Fix a bug introduced in #9923.
JAX release v0.3.3
Jax release v0.3.1
- Changes:
  - `jax.test_util.JaxTestCase` and `jax.test_util.JaxTestLoader` are now deprecated. The suggested replacement is to use `parameterized.TestCase` directly. For tests that rely on custom asserts such as `JaxTestCase.assertAllClose()`, the suggested replacement is to use standard NumPy testing utilities such as `numpy.testing.assert_allclose()`, which work directly with JAX arrays (#9620).
  - `jax.test_util.JaxTestCase` now sets `jax_numpy_rank_promotion='raise'` by default (#9562). To recover the previous behavior, use the new `jax.test_util.with_config` decorator:

    ```python
    from jax import test_util as jtu

    @jtu.with_config(jax_numpy_rank_promotion='allow')
    class MyTestCase(jtu.JaxTestCase):
      ...
    ```
  - Added `jax.scipy.linalg.schur`, `jax.scipy.linalg.sqrtm`, `jax.scipy.signal.csd`, `jax.scipy.signal.stft`, and `jax.scipy.signal.welch`.
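Following the suggested replacement for `JaxTestCase.assertAllClose()`, a test can compare JAX and NumPy results with `numpy.testing.assert_allclose` directly. This sketch uses the standard-library `unittest.TestCase` so that it needs no extra dependencies; `absl.testing.parameterized.TestCase` builds on `unittest.TestCase` and can be swapped in when parameterized tests are needed. The test class and method names are hypothetical.

```python
import unittest

import numpy as np
import jax
import jax.numpy as jnp


class CumsumTest(unittest.TestCase):  # hypothetical test class
    def test_cumsum_matches_numpy(self):
        x = np.arange(5, dtype=np.float32)
        # numpy.testing accepts JAX arrays directly; no JaxTestCase needed.
        np.testing.assert_allclose(jnp.cumsum(x), np.cumsum(x), rtol=1e-6)

    def test_device_check_via_public_api(self):
        # Device checks can go through the public jax.devices() API instead
        # of the deprecated jax.test_util.device_under_test helper.
        self.assertGreater(len(jax.devices()), 0)


if __name__ == "__main__":
    unittest.main()
```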
Jaxlib release v0.3.0
- Changes
  - Bazel 5.0.0 is now required to build jaxlib.
  - jaxlib version has been bumped to 0.3.0. Please see the design doc for the explanation.
Jax release v0.3.0
- Changes
  - jax version has been bumped to 0.3.0. Please see the design doc for the explanation.