We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
On Kaggle notebooks, saving the model with model.saved_model(x_sample, "saved-models/high-level") generates the following error:
model.saved_model(x_sample, "saved-models/high-level")
--------------------------------------------------------------------------- TypeError Traceback (most recent call last) /tmp/ipykernel_33/3229665597.py in <module> ----> 1 model.saved_model(x_sample, "saved-models/high-level") /opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in saved_model(self, inputs, path, batch_size) 769 enable_xla=True, 770 compile_model=False, --> 771 save_model_options=None, 772 ) /opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in convert_and_save_model(jax_fn, params, model_dir, input_signatures, shape_polymorphic_input_spec, with_gradient, enable_xla, compile_model, save_model_options) 99 signatures[ 100 tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY --> 101 ] = tf_fun.get_concrete_function(input_signatures[0]) 102 103 for input_signature in input_signatures[1:]: /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs) 1231 def get_concrete_function(self, *args, **kwargs): 1232 # Implements GenericFunction.get_concrete_function. 
-> 1233 concrete = self._get_concrete_function_garbage_collected(*args, **kwargs) 1234 concrete._garbage_collector.release() # pylint: disable=protected-access 1235 return concrete /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs) 1211 if self._stateful_fn is None: 1212 initializers = [] -> 1213 self._initialize(args, kwargs, add_initializers_to=initializers) 1214 self._initialize_uninitialized_variables(initializers) 1215 /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to) 758 self._concrete_stateful_fn = ( 759 self._stateful_fn._get_concrete_function_internal_garbage_collected( # pylint: disable=protected-access --> 760 *args, **kwds)) 761 762 def invalid_creator_scope(*unused_args, **unused_kwds): /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs) 3064 args, kwargs = None, None 3065 with self._lock: -> 3066 graph_function, _ = self._maybe_define_function(args, kwargs) 3067 return graph_function 3068 /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs) 3461 3462 self._function_cache.missed.add(call_context_key) -> 3463 graph_function = self._create_graph_function(args, kwargs) 3464 self._function_cache.primary[cache_key] = graph_function 3465 /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes) 3306 arg_names=arg_names, 3307 override_flat_arg_shapes=override_flat_arg_shapes, -> 3308 capture_by_value=self._capture_by_value), 3309 self._function_attributes, 3310 function_spec=self.function_spec, /opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, 
signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses) 1005 _, original_func = tf_decorator.unwrap(python_func) 1006 -> 1007 func_outputs = python_func(*func_args, **func_kwargs) 1008 1009 # invariant: `func_outputs` contains only Tensors, CompositeTensors, /opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds) 666 # the function a weak reference to itself to avoid a reference cycle. 667 with OptionalXlaContext(compile_with_xla): --> 668 out = weak_wrapped_fn().__wrapped__(*args, **kwds) 669 return out 670 /opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in <lambda>(inputs) 90 ) 91 tf_fun = tf.function( ---> 92 lambda inputs: tf_fn(param_vars, inputs), 93 autograph=False, 94 experimental_compile=compile_model, /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in converted_fun(*args, **kwargs) 441 else: 442 out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat, --> 443 name_stack, fresh_constant_cache=True) 444 outs, out_avals = util.unzip2(out_with_avals) 445 message = ("The jax2tf-converted function does not support gradients. 
" /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _interpret_fun(fun, in_vals, in_avals, extra_name_stack, fresh_constant_cache) 511 out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \ 512 _call_wrapped_with_new_constant_cache(fun, in_vals, --> 513 fresh_constant_cache=fresh_constant_cache) 514 515 del main /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _call_wrapped_with_new_constant_cache(fun, in_vals, fresh_constant_cache) 530 531 out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \ --> 532 fun.call_wrapped(*in_vals) 533 finally: 534 if prev_constant_cache is not None and not fresh_constant_cache: /opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs) 166 167 try: --> 168 ans = self.f(*args, **dict(self.params, **kwargs)) 169 except: 170 # Some transformations yield from inside context managers, so we have to /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in fun_no_kwargs(*args_and_kwargs) 284 kwargs = {kw: args_and_kwargs[nr_positional_args + i] 285 for i, kw in enumerate(kw_names)} --> 286 return fun(*args, **kwargs) 287 288 def check_arg(a): /opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in jax_fn(flat_states, inputs) 755 756 y_pred, _ = model.pred_step( --> 757 inputs=inputs, 758 ) 759 /opt/conda/lib/python3.7/site-packages/elegy/model/model.py in pred_step(self, inputs) 196 inputs_obj = tx.Inputs.from_value(inputs) 197 --> 198 preds = model.module(*inputs_obj.args, **inputs_obj.kwargs) 199 200 return preds, model /opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs) 114 @functools.wraps(cls.__call__) 115 def new_call(self: Module, *args, **kwargs): --> 116 outputs = orig_call(self, *args, **kwargs) 117 118 if ( /opt/conda/lib/python3.7/site-packages/treeo/api.py in wrapper(tree, *args, **kwargs) 518 def wrapper(tree, *args, **kwargs): 519 with 
tree_m._COMPACT_CONTEXT.compact(f, tree): --> 520 return f(tree, *args, **kwargs) 521 522 wrapper._treeo_compact = True /tmp/ipykernel_33/786988758.py in __call__(self, x) 14 x = eg.nn.Flatten()(x) 15 # first layers ---> 16 x = eg.nn.Linear(self.n1)(x) 17 x = jax.nn.relu(x) 18 # first layers /opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs) 114 @functools.wraps(cls.__call__) 115 def new_call(self: Module, *args, **kwargs): --> 116 outputs = orig_call(self, *args, **kwargs) 117 118 if ( /opt/conda/lib/python3.7/site-packages/treex/nn/linear.py in __call__(self, x) 119 params["bias"] = self.bias 120 --> 121 output = self.module.apply({"params": params}, x) 122 return tp.cast(jnp.ndarray, output) [... skipping hidden 7 frame] /opt/conda/lib/python3.7/site-packages/flax/linen/linear.py in __call__(self, inputs) 188 y = lax.dot_general(inputs, kernel, 189 (((inputs.ndim - 1,), (0,)), ((), ())), --> 190 precision=self.precision) 191 if self.use_bias: 192 bias = self.param('bias', self.bias_init, (self.features,), [... skipping hidden 3 frame] /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in process_primitive(self, primitive, tracers, params) 843 val_out = invoke_impl() 844 else: --> 845 val_out = invoke_impl() 846 847 if primitive.multiple_results: /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in invoke_impl() 809 _in_avals=args_avals, # type: ignore 810 _out_aval=out_aval, --> 811 **params) 812 else: 813 return impl(*args_tf, **params) /opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type, _in_avals, _out_aval) 1595 precision_config_proto, 1596 preferred_element_type=preferred_element_type, -> 1597 use_v2=True) 1598 if _WRAP_JAX_JIT_WITH_TF_FUNCTION: 1599 res = tf.stop_gradient(res) # See #7839 TypeError: dot_general() got an unexpected keyword argument 'use_v2'
The text was updated successfully, but these errors were encountered:
No branches or pull requests
On Kaggle notebooks, saving the model generates this error:
model.saved_model(x_sample, "saved-models/high-level")
The text was updated successfully, but these errors were encountered: