TypeError: '_DimPolynomial' object cannot be interpreted as an integer #245

mwitiderrick opened this issue Aug 20, 2022

Error when converting an LSTM model to TensorFlow with saved_model:

model.saved_model(np.array(X_test_padded[0]), "saved-models/high-level")
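For context, the traceback shows that saved_model exports the model through jax2tf (convert_and_save_model with a shape-polymorphic input spec) and then traces the result with tf.function. A minimal standalone sketch of the same failure mode, independent of elegy; the function, shapes, and polymorphic spec here are illustrative, not taken from the issue:

import jax.numpy as jnp
import tensorflow as tf
from jax.experimental import jax2tf

def fn(x):
    # len(x) must return a concrete Python int, but under shape polymorphism
    # the leading dimension of x is a symbolic _DimPolynomial.
    return jnp.zeros((len(x), 128))

tf_fn = tf.function(
    jax2tf.convert(fn, polymorphic_shapes=["(b, 100)"]),
    autograph=False,
)

# Tracing with an unknown batch dimension raises the same TypeError
# reported in this issue:
tf_fn.get_concrete_function(tf.TensorSpec([None, 100], tf.float32))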

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_127/2114008176.py in <module>
----> 1 model.saved_model(np.array(X_test_padded[0]), "saved-models/high-level")

/opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in saved_model(self, inputs, path, batch_size)
    769             enable_xla=True,
    770             compile_model=False,
--> 771             save_model_options=None,
    772         )

/opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in convert_and_save_model(jax_fn, params, model_dir, input_signatures, shape_polymorphic_input_spec, with_gradient, enable_xla, compile_model, save_model_options)
     99         signatures[
    100             tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
--> 101         ] = tf_fun.get_concrete_function(input_signatures[0])
    102 
    103         for input_signature in input_signatures[1:]:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1231   def get_concrete_function(self, *args, **kwargs):
   1232     # Implements GenericFunction.get_concrete_function.
-> 1233     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1234     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1235     return concrete

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1211       if self._stateful_fn is None:
   1212         initializers = []
-> 1213         self._initialize(args, kwargs, add_initializers_to=initializers)
   1214         self._initialize_uninitialized_variables(initializers)
   1215 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    758     self._concrete_stateful_fn = (
    759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 760             *args, **kwds))
    761 
    762     def invalid_creator_scope(*unused_args, **unused_kwds):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3306             arg_names=arg_names,
   3307             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3308             capture_by_value=self._capture_by_value),
   3309         self._function_attributes,
   3310         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

/opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in <lambda>(inputs)
     90         )
     91         tf_fun = tf.function(
---> 92             lambda inputs: tf_fn(param_vars, inputs),
     93             autograph=False,
     94             experimental_compile=compile_model,

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in converted_fun(*args, **kwargs)
    460       else:
    461         out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
--> 462                                         name_stack, fresh_constant_cache=True)
    463         outs, out_avals = util.unzip2(out_with_avals)
    464         message = ("The jax2tf-converted function does not support gradients. "

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _interpret_fun(fun, in_vals, in_avals, extra_name_stack, fresh_constant_cache)
    534           out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
    535               _call_wrapped_with_new_constant_cache(fun, in_vals,
--> 536                                                     fresh_constant_cache=fresh_constant_cache)
    537 
    538         del main

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _call_wrapped_with_new_constant_cache(fun, in_vals, fresh_constant_cache)
    687 
    688     out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
--> 689         fun.call_wrapped(*in_vals)
    690   finally:
    691     if prev_constant_cache is not None and not fresh_constant_cache:

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    166 
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:
    170       # Some transformations yield from inside context managers, so we have to

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in fun_no_kwargs(*args_and_kwargs)
    300       kwargs = {kw: args_and_kwargs[nr_positional_args + i]
    301                 for i, kw in enumerate(kw_names)}
--> 302       return fun(*args, **kwargs)
    303 
    304     def check_arg(a):

/opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in jax_fn(flat_states, inputs)
    755 
    756             y_pred, _ = model.pred_step(
--> 757                 inputs=inputs,
    758             )
    759 

/opt/conda/lib/python3.7/site-packages/elegy/model/model.py in pred_step(self, inputs)
    196         inputs_obj = tx.Inputs.from_value(inputs)
    197 
--> 198         preds = model.module(*inputs_obj.args, **inputs_obj.kwargs)
    199 
    200         return preds, model

/opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs)
    114             @functools.wraps(cls.__call__)
    115             def new_call(self: Module, *args, **kwargs):
--> 116                 outputs = orig_call(self, *args, **kwargs)
    117 
    118                 if (

/opt/conda/lib/python3.7/site-packages/treex/nn/flax_module.py in __call__(self, *args, **kwargs)
     96             rngs=rngs,
     97             method=method,
---> 98             **kwargs,
     99         )
    100         variables.update(updates.unfreeze())

    [... skipping hidden 5 frame]

/opt/conda/lib/python3.7/contextlib.py in inner(*args, **kwds)
     72         def inner(*args, **kwds):
     73             with self._recreate_cm():
---> 74                 return func(*args, **kwds)
     75         return inner
     76 

    [... skipping hidden 15 frame]

/tmp/ipykernel_127/3182282043.py in __call__(self, x_batch)
     23         x = self.embedding(x_batch)
     24 
---> 25         carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=128)
     26         (carry, hidden), x = self.lstm1((carry, hidden), x)
     27 

TypeError: '_DimPolynomial' object cannot be interpreted as an integer
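The last frame points at the immediate cause: during the jax2tf export the batch dimension of x_batch is a symbolic _DimPolynomial, and Python's len() insists on a concrete integer. A possible adjustment to the lines of __call__ visible in that frame (only those lines are shown; the rest of the module is omitted), keeping the batch dimension symbolic instead of forcing it to an int. This is a sketch of a workaround, not a confirmed fix; later stages of the LSTM conversion may still hit other shape-polymorphism limits:

# Inside the module's __call__, as shown in the last traceback frame:
x = self.embedding(x_batch)

# Use the (possibly symbolic) leading dimension from the shape instead of
# len(), which must return a concrete Python int:
batch_size = x_batch.shape[0]
carry, hidden = nn.OptimizedLSTMCell.initialize_carry(
    jax.random.PRNGKey(0), batch_dims=(batch_size,), size=128
)
(carry, hidden), x = self.lstm1((carry, hidden), x)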