Skip to content

Commit

Permalink
python312Packages.flax: 0.8.5 -> 0.9.0 (NixOS#342970)
Browse files Browse the repository at this point in the history
  • Loading branch information
pbsds committed Sep 20, 2024
2 parents a0d059d + 588d188 commit ffe5c6c
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 6 deletions.
16 changes: 10 additions & 6 deletions pkgs/development/python-modules/flax/default.nix
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
{
lib,
buildPythonPackage,
pythonOlder,
fetchFromGitHub,

# build-system
Expand All @@ -26,23 +25,22 @@
pytest-xdist,
pytestCheckHook,
tensorflow,
treescope,

# optional-dependencies
matplotlib,
}:

buildPythonPackage rec {
pname = "flax";
version = "0.8.5";
version = "0.9.0";
pyproject = true;

disabled = pythonOlder "3.9";

src = fetchFromGitHub {
owner = "google";
repo = "flax";
rev = "refs/tags/v${version}";
hash = "sha256-6WOFq0758gtNdrlWqSQBlKmWVIGe5e4PAaGrvHoGjr0=";
hash = "sha256-iDWuUJKO7V4QrbVsS4ALgy6fbllOC43o7W4mhjtZ9xc=";
};

build-system = [
Expand Down Expand Up @@ -75,6 +73,7 @@ buildPythonPackage rec {
pytest-xdist
pytestCheckHook
tensorflow
treescope
];

pytestFlagsArray = [
Expand All @@ -95,13 +94,18 @@ buildPythonPackage rec {
"flax/nnx/examples/*"
# See https://github.com/google/flax/issues/3232.
"tests/jax_utils_test.py"
# Requires tree
# Too old version of tensorflow:
# ModuleNotFoundError: No module named 'keras.api._v2'
"tests/tensorboard_test.py"
];

disabledTests = [
# ValueError: Checkpoint path should be absolute
"test_overwrite_checkpoints0"
# Fixed in more recent versions of jax: https://github.com/google/flax/issues/4211
# TODO: Re-enable when jax>0.4.28 will be available in nixpkgs
"test_vmap_and_cond_passthrough" # ValueError: vmap has mapped output but out_axes is None
"test_vmap_and_cond_passthrough_error" # AssertionError: "at vmap.*'broadcast'.*got axis spec ...
];

meta = {
Expand Down
65 changes: 65 additions & 0 deletions pkgs/development/python-modules/treescope/default.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
{
lib,
buildPythonPackage,
fetchFromGitHub,

# build-system
flit-core,

# dependencies
numpy,

# optional-dependencies
ipython,
jax,
palettable,

# tests
absl-py,
jaxlib,
pytestCheckHook,
torch,
}:

buildPythonPackage rec {
pname = "treescope";
version = "0.1.5";
pyproject = true;

src = fetchFromGitHub {
owner = "google-deepmind";
repo = "treescope";
rev = "refs/tags/v${version}";
hash = "sha256-+Hm60O9tEXIiE0av1O0BsOdMln4e1s7ijb3WNiQ74jE=";
};

build-system = [ flit-core ];

dependencies = [ numpy ];

optional-dependencies = {
notebook = [
ipython
jax
palettable
];
};

pythonImportsCheck = [ "treescope" ];

nativeCheckInputs = [
absl-py
jax
jaxlib
pytestCheckHook
torch
];

meta = {
description = "An interactive HTML pretty-printer for machine learning research in IPython notebooks";
homepage = "https://github.com/google-deepmind/treescope";
changelog = "https://github.com/google-deepmind/treescope/releases/tag/v${version}";
license = lib.licenses.asl20;
maintainers = with lib.maintainers; [ GaetanLepage ];
};
}
2 changes: 2 additions & 0 deletions pkgs/top-level/python-packages.nix
Original file line number Diff line number Diff line change
Expand Up @@ -15778,6 +15778,8 @@ self: super: with self; {

treeo = callPackage ../development/python-modules/treeo { };

treescope = callPackage ../development/python-modules/treescope { };

treex = callPackage ../development/python-modules/treex { };

treq = callPackage ../development/python-modules/treq { };
Expand Down

0 comments on commit ffe5c6c

Please sign in to comment.