diff --git a/elegy/__init__.py b/elegy/__init__.py index a61afa35..65dfb9dc 100644 --- a/elegy/__init__.py +++ b/elegy/__init__.py @@ -1,6 +1,6 @@ # isort:skip_file -__version__ = "0.8.2" +__version__ = "0.8.3" from treex import * diff --git a/elegy/model/model.py b/elegy/model/model.py index d1b2cc1d..3c3dc796 100644 --- a/elegy/model/model.py +++ b/elegy/model/model.py @@ -1,16 +1,11 @@ import typing as tp -from io import StringIO -import einops import flax import jax -import jax.experimental.host_callback as hcb import jax.numpy as jnp import numpy as np import treex as tx -from jax._src.tree_util import tree_map from optax import GradientTransformation -from treex.nn.haiku_module import HaikuModule from elegy import types, utils from elegy.model.model_base import ModelBase @@ -28,9 +23,11 @@ import haiku as hk TransformedWithState = hk.TransformedWithState -except ImportError: + HaikuModule = tx.HaikuModule +except (ImportError, ModuleNotFoundError): hk = None TransformedWithState = type(None) + HaikuModule = tp.cast(tp.Any, None) class Model(tp.Generic[U], ModelBase): @@ -62,7 +59,7 @@ def __init__( @tp.overload def __init__( - self: "Model[tx.HaikuModule]", + self: "Model[HaikuModule]", module: hk.TransformedWithState, loss: tp.Any = None, metrics: tp.Any = None, @@ -134,7 +131,7 @@ def __init__( elif TransformedWithState is not None and isinstance( module, TransformedWithState ): - self.module = tx.HaikuModule(module) + self.module = HaikuModule(module) else: self.module = module diff --git a/pyproject.toml b/pyproject.toml index 1fe4f0d1..669e3b01 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,7 @@ authors = ["Cristian Garcia ", "Carlos Alvarez ", "David Cardozo ", "Sebastian Arango"] -version = "0.8.2" +version = "0.8.3" license = "APACHE" readme = "README.md" repository = "https://github.com/poets-ai/elegy"