Skip to content

Commit

Permalink
Fix import error when Haiku is not installed (#213)
Browse files Browse the repository at this point in the history
* fix haiku import issue

* 0.8.3

* fix HaikuModule
  • Loading branch information
cgarciae committed Dec 13, 2021
1 parent 75aa862 commit 227f0af
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 10 deletions.
2 changes: 1 addition & 1 deletion elegy/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# isort:skip_file

__version__ = "0.8.2"
__version__ = "0.8.3"

from treex import *

Expand Down
13 changes: 5 additions & 8 deletions elegy/model/model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,11 @@
import typing as tp
from io import StringIO

import einops
import flax
import jax
import jax.experimental.host_callback as hcb
import jax.numpy as jnp
import numpy as np
import treex as tx
from jax._src.tree_util import tree_map
from optax import GradientTransformation
from treex.nn.haiku_module import HaikuModule

from elegy import types, utils
from elegy.model.model_base import ModelBase
Expand All @@ -28,9 +23,11 @@
import haiku as hk

TransformedWithState = hk.TransformedWithState
except ImportError:
HaikuModule = tx.HaikuModule
except (ImportError, ModuleNotFoundError):
hk = None
TransformedWithState = type(None)
HaikuModule = tp.cast(tp.Any, None)


class Model(tp.Generic[U], ModelBase):
Expand Down Expand Up @@ -62,7 +59,7 @@ def __init__(

@tp.overload
def __init__(
self: "Model[tx.HaikuModule]",
self: "Model[HaikuModule]",
module: hk.TransformedWithState,
loss: tp.Any = None,
metrics: tp.Any = None,
Expand Down Expand Up @@ -134,7 +131,7 @@ def __init__(
elif TransformedWithState is not None and isinstance(
module, TransformedWithState
):
self.module = tx.HaikuModule(module)
self.module = HaikuModule(module)
else:
self.module = module

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ authors = ["Cristian Garcia <[email protected]>",
"Carlos Alvarez <[email protected]>",
"David Cardozo <[email protected]>",
"Sebastian Arango"]
version = "0.8.2"
version = "0.8.3"
license = "APACHE"
readme = "README.md"
repository = "https://github.com/poets-ai/elegy"
Expand Down

0 comments on commit 227f0af

Please sign in to comment.