We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
model.saved_model(np.array(X_test_padded[0]), "saved-models/high-level")
TypeError ("'_DimPolynomial' object cannot be interpreted as an integer") when converting an LSTM model to TensorFlow with `saved_model`
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /tmp/ipykernel_127/2114008176.py in <module> ----> 1 model.saved_model(np.array(X_test_padded[0]), "saved-models/high-level") /opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in saved_model(self, inputs, path, batch_size) 769 enable_xla=True, 770 compile_model=False, --> 771 save_model_options=None, 772 ) /opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in convert_and_save_model(jax_fn, params, model_dir, input_signatures, shape_polymorphic_input_spec, with_gradient, enable_xla, compile_model, save_model_options) 99 signatures[ 100 tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY --> 101 ] = tf_fun.get_concrete_function(input_signatures[0]) 102 103 for input_signature in input_signatures[1:]: /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs) 1231 def get_concrete_function(self, *args, **kwargs): 1232 # Implements GenericFunction.get_concrete_function. 
-> 1233 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs) 1234 concrete._garbage_collector.release() # pylint: disable=protected-access 1235 return concrete /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs) 1211 if self._stateful_fn is None: 1212 initializers = [] -> 1213 self._initialize(args, kwargs, add_initializers_to=initializers) 1214 self._initialize_uninitialized_variables(initializers) 1215 /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to) 758 self._concrete_stateful_fn = ( 759 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access --> 760 *args, **kwds)) 761 762 def invalid_creator_scope(*unused_args, **unused_kwds): /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs) 3064 args, kwargs = None, None 3065 with self._lock: -> 3066 graph_function, _ = self._maybe_define_function(args, kwargs) 3067 return graph_function 3068 /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs) 3461 3462 self._function_cache.missed.add(call_context_key) -> 3463 graph_function = self._create_graph_function(args, kwargs) 3464 self._function_cache.primary[cache_key] = graph_function 3465 /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes) 3306 arg_names=arg_names, 3307 override_flat_arg_shapes=override_flat_arg_shapes, -> 3308 capture_by_value=self._capture_by_value), 3309 self._function_attributes, 3310 function_spec=self.function_spec, /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, 
signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses) 1005 _, original_func = tf_decorator.unwrap(python_func) 1006 -> 1007 func_outputs = python_func(*func_args, **func_kwargs) 1008 1009 # invariant: `func_outputs` contains only Tensors, CompositeTensors, /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds) 666 # the function a weak reference to itself to avoid a reference cycle. 667 with OptionalXlaContext(compile_with_xla): --> 668 out = weak_wrapped_fn().__wrapped__(*args, **kwds) 669 return out 670 /opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in <lambda>(inputs) 90 ) 91 tf_fun = tf.function( ---> 92 lambda inputs: tf_fn(param_vars, inputs), 93 autograph=False, 94 experimental_compile=compile_model, /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in converted_fun(*args, **kwargs) 460 else: 461 out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat, --> 462 name_stack, fresh_constant_cache=True) 463 outs, out_avals = util.unzip2(out_with_avals) 464 message = ("The jax2tf-converted function does not support gradients. 
" /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _interpret_fun(fun, in_vals, in_avals, extra_name_stack, fresh_constant_cache) 534 out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \ 535 _call_wrapped_with_new_constant_cache(fun, in_vals, --> 536 fresh_constant_cache=fresh_constant_cache) 537 538 del main /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _call_wrapped_with_new_constant_cache(fun, in_vals, fresh_constant_cache) 687 688 out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \ --> 689 fun.call_wrapped(*in_vals) 690 finally: 691 if prev_constant_cache is not None and not fresh_constant_cache: /opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs) 166 167 try: --> 168 ans = self.f(*args, **dict(self.params, **kwargs)) 169 except: 170 # Some transformations yield from inside context managers, so we have to /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in fun_no_kwargs(*args_and_kwargs) 300 kwargs = {kw: args_and_kwargs[nr_positional_args + i] 301 for i, kw in enumerate(kw_names)} --> 302 return fun(*args, **kwargs) 303 304 def check_arg(a): /opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in jax_fn(flat_states, inputs) 755 756 y_pred, _ = model.pred_step( --> 757 inputs=inputs, 758 ) 759 /opt/conda/lib/python3.7/site-packages/elegy/model/model.py in pred_step(self, inputs) 196 inputs_obj = tx.Inputs.from_value(inputs) 197 --> 198 preds = model.module(*inputs_obj.args, **inputs_obj.kwargs) 199 200 return preds, model /opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs) 114 @functools.wraps(cls.__call__) 115 def new_call(self: Module, *args, **kwargs): --> 116 outputs = orig_call(self, *args, **kwargs) 117 118 if ( /opt/conda/lib/python3.7/site-packages/treex/nn/flax_module.py in __call__(self, *args, **kwargs) 96 rngs=rngs, 97 method=method, ---> 98 **kwargs, 99 ) 100 
variables.update(updates.unfreeze()) [... skipping hidden 5 frame] /opt/conda/lib/python3.7/contextlib.py in inner(*args, **kwds) 72 def inner(*args, **kwds): 73 with self._recreate_cm(): ---> 74 return func(*args, **kwds) 75 return inner 76 [... skipping hidden 15 frame] /tmp/ipykernel_127/3182282043.py in __call__(self, x_batch) 23 x = self.embedding(x_batch) 24 ---> 25 carry, hidden = nn.OptimizedLSTMCell.initialize_carry(jax.random.PRNGKey(0), batch_dims=(len(x_batch),), size=128) 26 (carry, hidden), x = self.lstm1((carry, hidden), x) 27 TypeError: '_DimPolynomial' object cannot be interpreted as an integer
The text was updated successfully, but these errors were encountered:
No branches or pull requests
TypeError ("'_DimPolynomial' object cannot be interpreted as an integer") when converting an LSTM model to TensorFlow with `saved_model`
The text was updated successfully, but these errors were encountered: