@@ -119,9 +119,10 @@ def sqrt(self, tensor: Tensor) -> Tensor:
119119 return jnp .sqrt (tensor )
120120
121121 def convert_to_tensor (self , tensor : Tensor ) -> Tensor :
122- if (not isinstance (tensor , jnp .ndarray ) and not jnp .isscalar (tensor )):
123- raise TypeError ("Expected a `jnp.array` or scalar. Got {}" .format (
124- type (tensor )))
122+ if (not isinstance (tensor , (np .ndarray , jnp .ndarray ))
123+ and not jnp .isscalar (tensor )):
124+ raise TypeError (("Expected a `jnp.array`, `np.array` or scalar. "
125+ f"Got { type (tensor )} " ))
125126 result = jnp .asarray (tensor )
126127 return result
127128
@@ -320,7 +321,7 @@ def A(H,x):
320321 "`dtype` have to be provided" )
321322 initial_state = self .randn (shape , dtype )
322323
323- if not isinstance (initial_state , jnp .ndarray ):
324+ if not isinstance (initial_state , ( jnp .ndarray , np . ndarray ) ):
324325 raise TypeError ("Expected a `jax.array`. Got {}" .format (
325326 type (initial_state )))
326327
@@ -435,7 +436,7 @@ def A(H,x):
435436 "`dtype` have to be provided" )
436437 initial_state = self .randn (shape , dtype )
437438
438- if not isinstance (initial_state , jnp .ndarray ):
439+ if not isinstance (initial_state , ( jnp .ndarray , np . ndarray ) ):
439440 raise TypeError ("Expected a `jax.array`. Got {}" .format (
440441 type (initial_state )))
441442
@@ -555,7 +556,7 @@ def A(H,x):
555556 "`dtype` have to be provided" )
556557 initial_state = self .randn (shape , dtype )
557558
558- if not isinstance (initial_state , jnp .ndarray ):
559+ if not isinstance (initial_state , ( jnp .ndarray , np . ndarray ) ):
559560 raise TypeError ("Expected a `jax.array`. Got {}" .format (
560561 type (initial_state )))
561562 if A not in _CACHED_MATVECS :
0 commit comments