|
1 |
| -try: |
2 |
| - import pytensor.tensor as pt |
3 |
| - from pytensor.graph.basic import Constant, Variable |
4 |
| - from pytensor.graph.fg import MissingInputError |
5 |
| - from pytensor.graph.op import Op |
6 |
| - from pytensor.gradient import grad_not_implemented |
7 |
| -except ModuleNotFoundError: |
8 |
| - import aesara.tensor as pt |
9 |
| - from aesara.graph.basic import Constant, Variable |
10 |
| - from aesara.graph.fg import MissingInputError |
11 |
| - from aesara.graph.op import Op |
12 |
| - from aesara.gradient import grad_not_implemented |
| 1 | +import pytensor.tensor as pt |
| 2 | +from pytensor.graph.basic import Constant, Variable |
| 3 | +from pytensor.graph.fg import MissingInputError |
| 4 | +from pytensor.graph.op import Op |
| 5 | +from pytensor.gradient import grad_not_implemented |
13 | 6 | import copy
|
14 | 7 | from typing import Dict, Optional, Any, Callable
|
15 | 8 |
|
@@ -52,12 +45,11 @@ def read_dict(vals, name=None):
|
52 | 45 | if isinstance(vals, tuple):
|
53 | 46 | tensor, dim_names = vals
|
54 | 47 | else:
|
55 |
| - try: |
56 |
| - tensor, dim_names = vals, pt.as_tensor_variable(vals, dtype="float64").shape.eval() |
57 |
| - except MissingInputError as e: |
| 48 | + tensor, dim_names = vals, pt.as_tensor_variable(vals, dtype="float64").type.shape |
| 49 | + if any(d is None for d in dim_names): |
58 | 50 | raise ValueError(
|
59 |
| - 'Shapes of tensors need to be statically ' |
60 |
| - 'known or given explicitly.') from e |
| 51 | + 'Shapes of tensors need to be statically known or given explicitly.' |
| 52 | + ) |
61 | 53 | if isinstance(dim_names, (str, int)):
|
62 | 54 | dim_names = (dim_names,)
|
63 | 55 | tensor = pt.as_tensor_variable(tensor, dtype="float64")
|
|
0 commit comments