Skip to content
This repository was archived by the owner on Nov 7, 2024. It is now read-only.

Commit edf3f34

Browse files
authored
Fix build (#939)
* fix broken build * remove unneeded import * remove print * fix indent * fix tests after h5py upgrade * silence the linter * linting * more linter-silencing * fix tensordot2 test * fix string encoding * fix torch version, change pylint indent to 2 * change yapf file * nit * remove print * silence linter * add comment * fix number of epochs in the test * fix convert_to_tensor in JaxBackend * fix type tests for jax-arrays
1 parent 09181e6 commit edf3f34

File tree

10 files changed

+18
-16
lines changed

10 files changed

+18
-16
lines changed

.style.yapf

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
[style]
22
# TensorNetwork uses the yapf style
33
based_on_style = yapf
4+
indent_width = 2

requirements_travis.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
tensorflow>=2.0.0
22
pytest
3-
torch>=1.4.0
3+
torch==1.8.1 # TODO (mganahl): remove restriction once torch.tensordot bug is fixed (https://github.com/pytorch/pytorch/issues/65524)
44
jax>=0.1.68
55
jaxlib>=0.1.59
66
pylint==2.5.3

tensornetwork/backends/jax/jax_backend.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,10 @@ def sqrt(self, tensor: Tensor) -> Tensor:
119119
return jnp.sqrt(tensor)
120120

121121
def convert_to_tensor(self, tensor: Tensor) -> Tensor:
122-
if (not isinstance(tensor, jnp.ndarray) and not jnp.isscalar(tensor)):
123-
raise TypeError("Expected a `jnp.array` or scalar. Got {}".format(
124-
type(tensor)))
122+
if (not isinstance(tensor, (np.ndarray, jnp.ndarray))
123+
and not jnp.isscalar(tensor)):
124+
raise TypeError(("Expected a `jnp.array`, `np.array` or scalar. "
125+
f"Got {type(tensor)}"))
125126
result = jnp.asarray(tensor)
126127
return result
127128

@@ -320,7 +321,7 @@ def A(H,x):
320321
"`dtype` have to be provided")
321322
initial_state = self.randn(shape, dtype)
322323

323-
if not isinstance(initial_state, jnp.ndarray):
324+
if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
324325
raise TypeError("Expected a `jax.array`. Got {}".format(
325326
type(initial_state)))
326327

@@ -435,7 +436,7 @@ def A(H,x):
435436
"`dtype` have to be provided")
436437
initial_state = self.randn(shape, dtype)
437438

438-
if not isinstance(initial_state, jnp.ndarray):
439+
if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
439440
raise TypeError("Expected a `jax.array`. Got {}".format(
440441
type(initial_state)))
441442

@@ -555,7 +556,7 @@ def A(H,x):
555556
"`dtype` have to be provided")
556557
initial_state = self.randn(shape, dtype)
557558

558-
if not isinstance(initial_state, jnp.ndarray):
559+
if not isinstance(initial_state, (jnp.ndarray, np.ndarray)):
559560
raise TypeError("Expected a `jax.array`. Got {}".format(
560561
type(initial_state)))
561562
if A not in _CACHED_MATVECS:

tensornetwork/tn_keras/condenser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
# pytype: disable=module-attr
14-
@tf.keras.utils.register_keras_serializable(package='tensornetwork')
14+
@tf.keras.utils.register_keras_serializable(package='tensornetwork')# pylint: disable=no-member
1515
# pytype: enable=module-attr
1616
class DenseCondenser(Layer):
1717
"""Condenser TN layer. Greatly reduces dimensionality of input.

tensornetwork/tn_keras/conv2d_mpo.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import math
1111

1212
# pytype: disable=module-attr
13-
@tf.keras.utils.register_keras_serializable(package='tensornetwork')
13+
@tf.keras.utils.register_keras_serializable(package='tensornetwork')# pylint: disable=no-member
1414
# pytype: enable=module-attr
1515
class Conv2DMPO(Layer):
1616
"""2D Convolutional Matrix Product Operator (MPO) TN layer.
@@ -197,7 +197,7 @@ def is_perfect_root(n, n_nodes):
197197
else:
198198
self.use_bias = None
199199

200-
def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: #pylint: disable=arguments-differ
200+
def call(self, inputs: tf.Tensor) -> tf.Tensor: #pylint: disable=arguments-differ
201201

202202
tn_nodes = [tn.Node(n, backend='tensorflow') for n in self.nodes]
203203
for i in range(len(tn_nodes) - 1):

tensornetwork/tn_keras/dense.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010

1111
# pytype: disable=module-attr
12-
@tf.keras.utils.register_keras_serializable(package='tensornetwork')
12+
@tf.keras.utils.register_keras_serializable(package='tensornetwork')# pylint: disable=no-member
1313
# pytype: enable=module-attr
1414
class DenseDecomp(Layer):
1515
"""TN layer comparable to Dense that carries out matrix multiplication

tensornetwork/tn_keras/entangler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
# pytype: disable=module-attr
14-
@tf.keras.utils.register_keras_serializable(package='tensornetwork')
14+
@tf.keras.utils.register_keras_serializable(package='tensornetwork')# pylint: disable=no-member
1515
# pytype: enable=module-attr
1616
class DenseEntangler(Layer):
1717
"""Entangler TN layer. Allows for very large hidden layers.

tensornetwork/tn_keras/expander.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
# pytype: disable=module-attr
14-
@tf.keras.utils.register_keras_serializable(package='tensornetwork')
14+
@tf.keras.utils.register_keras_serializable(package='tensornetwork') # pylint: disable=no-member
1515
# pytype: enable=module-attr
1616
class DenseExpander(Layer):
1717
"""Expander TN layer. Greatly expands dimensionality of input.

tensornetwork/tn_keras/mpo.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212

1313
# pytype: disable=module-attr
14-
@tf.keras.utils.register_keras_serializable(package='tensornetwork')
14+
@tf.keras.utils.register_keras_serializable(package='tensornetwork')# pylint: disable=no-member
1515
# pytype: enable=module-attr
1616
class DenseMPO(Layer):
1717
"""Matrix Product Operator (MPO) TN layer.

tensornetwork/tn_keras/test_layer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ def test_train(dummy_data, make_model):
113113
loss='binary_crossentropy',
114114
metrics=['accuracy'])
115115

116-
# Train the model for 10 epochs
117-
history = model.fit(data, labels, epochs=50, batch_size=64)
116+
# Train the model for 40 epochs
117+
history = model.fit(data, labels, epochs=40, batch_size=64)
118118

119119
# Check that loss decreases and accuracy increases
120120
assert history.history['loss'][0] > history.history['loss'][-1]

0 commit comments

Comments
 (0)