JAX v0.4.36

@hawkinsp hawkinsp released this 05 Dec 23:33
· 1500 commits to main since this release
  • Breaking Changes

    • This release lands "stackless", an internal change to JAX's tracing
      machinery. We made trace dispatch purely a function of context rather than a
      function of both context and data. This let us delete a lot of machinery for
      managing data-dependent tracing: levels, sublevels, post_process_call,
      new_base_main, custom_bind, and so on. The change should only affect
      users that use JAX internals.

      If you do use JAX internals then you may need to
      update your code (see
      for clues about how to do this). There might also be version skew
      issues with JAX libraries that do this. If you find this change breaks your
      non-JAX-internals-using code then try the
      config.jax_data_dependent_tracing_fallback flag as a workaround, and if
      you need help updating your code then please file a bug.

    • jax.experimental.jax2tf.convert with native_serialization=False
      or with enable_xla=False has been deprecated since July 2024 (JAX
      version 0.4.31). Support for these use cases has now been removed. jax2tf
      with native serialization will still be supported.

    • In jax.interpreters.xla, the xb, xc, and xe symbols have been removed
      after being deprecated in JAX v0.4.31. Instead use xb = jax.lib.xla_bridge,
      xc = jax.lib.xla_client, and xe = jax.lib.xla_extension.

    • The deprecated module jax.experimental.export has been removed. It was replaced
      by jax.export in JAX v0.4.30. See the migration guide
      for information on migrating to the new API.

    • The initial argument to jax.nn.softmax and jax.nn.log_softmax
      has been removed, after being deprecated in v0.4.27.

    • Calling np.asarray on typed PRNG keys (i.e. keys produced by jax.random.key)
      now raises an error. Previously, this returned a scalar object array.
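
A minimal sketch of reading the raw bits of a typed key without np.asarray(key). It assumes jax.random.key_data is the intended replacement, which the note above does not state, and it catches a broad Exception because the note does not name the error type.

```python
import jax
import numpy as np

key = jax.random.key(0)  # typed PRNG key

# Direct conversion is expected to fail as of this release. The exact
# exception class is an assumption, so it is caught broadly here.
try:
    np.asarray(key)
    conversion_failed = False
except Exception:
    conversion_failed = True

# jax.random.key_data returns the underlying key bits as an ordinary
# integer array, which NumPy can convert.
raw = np.asarray(jax.random.key_data(key))
```

With the default PRNG implementation, raw is a small uint32 array.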

    • The following deprecated methods and functions in jax.export have
      been removed:

      • jax.export.DisabledSafetyCheck.shape_assertions: it had no effect
      • jax.export.Exported.lowering_platforms: use platforms.
      • jax.export.Exported.mlir_module_serialization_version:
        use calling_convention_version.
      • jax.export.Exported.uses_shape_polymorphism:
        use uses_global_constants.
      • the lowering_platforms kwarg for jax.export.export: use
        platforms instead.
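
The renamed attributes can be sketched as follows. The function double and the choice of the "cpu" platform are illustrative assumptions, not part of the release notes.

```python
import jax
import jax.numpy as jnp
from jax import export


def double(x):
    # Hypothetical function used only to produce an Exported object.
    return 2 * x


x = jnp.ones(3, dtype=jnp.float32)

# Use the platforms kwarg (formerly lowering_platforms).
exported = export.export(jax.jit(double), platforms=["cpu"])(x)

target_platforms = exported.platforms                # was lowering_platforms
version = exported.calling_convention_version        # was mlir_module_serialization_version
has_global_consts = exported.uses_global_constants   # was uses_shape_polymorphism
```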
    • The kwargs symbolic_scope and symbolic_constraints from
      jax.export.symbolic_args_specs have been removed. They were
      deprecated in June 2024. Use scope and constraints instead.

    • Hashing of tracers, which has been deprecated since version 0.4.30, now
      results in a TypeError.

    • Refactor: JAX build CLI (build/ now uses a subcommand structure and
      replaces previous usage. Run python build/ --help for
      more details. Brief overview of the new subcommand options:

      • build: Builds JAX wheel packages. For example, python build/ build --wheels=jaxlib,jax-cuda-pjrt
      • requirements_update: Updates requirements_lock.txt files.
    • jax.scipy.linalg.toeplitz now does implicit batching on multi-dimensional
      inputs. To recover the previous behavior, you can call jax.numpy.ravel
      on the function inputs.

    • jax.scipy.special.gamma and jax.scipy.special.gammasgn now
      return NaN for negative integer inputs, to match the behavior of SciPy from

    • jax.clear_backends was removed after being deprecated in v0.4.26.

    • We removed the custom call "__gpu$xla.gpu.triton" from the list of custom
      calls for which we guarantee export stability. This is because this custom call
      relies on Triton IR, which is not guaranteed to be stable. If you need
      to export code that uses this custom call, you can use the disabled_checks
      parameter. See more details in the documentation.
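
A minimal sketch of passing disabled_checks, assuming that jax.export.DisabledSafetyCheck.custom_call is the check meant here. The function double is a stand-in and does not itself contain the Triton custom call, so this only shows the call shape.

```python
import jax
import jax.numpy as jnp
from jax import export


def double(x):
    # Stand-in computation; a real use case would lower to the Triton custom call.
    return 2 * x


# Explicitly allow a custom call target that is no longer on the
# guaranteed-stable list.
allow_triton = export.DisabledSafetyCheck.custom_call("__gpu$xla.gpu.triton")

x = jnp.arange(3, dtype=jnp.float32)
exported = export.export(jax.jit(double), disabled_checks=[allow_triton])(x)
y = exported.call(x)
```

Disabling the check removes the stability guarantee for that call, so the exported artifact may stop working if Triton IR changes.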

  • New Features

    • jax.jit got a new compiler_options: dict[str, Any] argument, for
      passing compilation options to XLA. For the moment it's undocumented and
      may be in flux.
    • jax.tree_util.register_dataclass now allows metadata fields to be
      declared inline via dataclasses.field. See the function documentation
      for examples.
    • Added jax.numpy.put_along_axis.
    • jax.lax.linalg.eig and the related jax.numpy functions
      (jax.numpy.linalg.eig and jax.numpy.linalg.eigvals) are now
      supported on GPU. See #24663 for more details.
    • Added two new configuration flags, jax_exec_time_optimization_effort and jax_memory_fitting_effort, to control the amount of effort the compiler spends minimizing execution time and memory usage, respectively. Valid values are between -1.0 and 1.0, default is 0.0.
  • Bug fixes

    • Fixed a bug where the GPU implementations of LU and QR decomposition would
      result in an indexing overflow for batch sizes close to int32 max. See
      #24843 for more details.
  • Deprecations

    • jax.lib.xla_extension.ArrayImpl and jax.lib.xla_client.ArrayImpl are deprecated;
      use jax.Array instead.
    • jax.lib.xla_extension.XlaRuntimeError is deprecated; use jax.errors.JaxRuntimeError
      instead.