TypeError: dot_general() got an unexpected keyword argument 'use_v2' #240

mwitiderrick opened this issue Jul 1, 2022 · 0 comments
On Kaggle notebooks, calling model.saved_model to export the model generates this error:
model.saved_model(x_sample, "saved-models/high-level")
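
For context, here is a minimal sketch of the kind of setup that reaches this call, reconstructed from the user frames in the traceback below. The module name MLP, the layer width 300, and the shape of x_sample are illustrative placeholders rather than the exact notebook code, and the @eg.compact decorator / constructor details may differ slightly between elegy versions:

```python
import jax
import jax.numpy as jnp
import elegy as eg

class MLP(eg.Module):
    # Placeholder module: the traceback shows Flatten -> Linear -> relu
    # inside a compact __call__, so a minimal equivalent looks like this.
    @eg.compact
    def __call__(self, x):
        x = eg.nn.Flatten()(x)
        x = eg.nn.Linear(300)(x)  # stands in for eg.nn.Linear(self.n1) in the notebook
        x = jax.nn.relu(x)
        return x

model = eg.Model(module=MLP())
x_sample = jnp.zeros((1, 28, 28), dtype=jnp.float32)  # placeholder input batch
model.predict(x_sample)  # initialize parameters before exporting
model.saved_model(x_sample, "saved-models/high-level")  # raises the TypeError below
```

The export goes through jax2tf (elegy/model/utils.py, convert_and_save_model in the trace), which is where the failing dot_general call originates.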

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_33/3229665597.py in <module>
----> 1 model.saved_model(x_sample, "saved-models/high-level")

/opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in saved_model(self, inputs, path, batch_size)
    769             enable_xla=True,
    770             compile_model=False,
--> 771             save_model_options=None,
    772         )

/opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in convert_and_save_model(jax_fn, params, model_dir, input_signatures, shape_polymorphic_input_spec, with_gradient, enable_xla, compile_model, save_model_options)
     99         signatures[
    100             tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY
--> 101         ] = tf_fun.get_concrete_function(input_signatures[0])
    102 
    103         for input_signature in input_signatures[1:]:

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in get_concrete_function(self, *args, **kwargs)
   1231   def get_concrete_function(self, *args, **kwargs):
   1232     # Implements GenericFunction.get_concrete_function.
-> 1233     concrete = self._get_concrete_function_garbage_collected(*args, **kwargs)
   1234     concrete._garbage_collector.release()  # pylint: disable=protected-access
   1235     return concrete

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _get_concrete_function_garbage_collected(self, *args, **kwargs)
   1211       if self._stateful_fn is None:
   1212         initializers = []
-> 1213         self._initialize(args, kwargs, add_initializers_to=initializers)
   1214         self._initialize_uninitialized_variables(initializers)
   1215 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in _initialize(self, args, kwds, add_initializers_to)
    758     self._concrete_stateful_fn = (
    759         self._stateful_fn._get_concrete_function_internal_garbage_collected(  # pylint: disable=protected-access
--> 760             *args, **kwds))
    761 
    762     def invalid_creator_scope(*unused_args, **unused_kwds):

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _get_concrete_function_internal_garbage_collected(self, *args, **kwargs)
   3064       args, kwargs = None, None
   3065     with self._lock:
-> 3066       graph_function, _ = self._maybe_define_function(args, kwargs)
   3067     return graph_function
   3068 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   3461 
   3462           self._function_cache.missed.add(call_context_key)
-> 3463           graph_function = self._create_graph_function(args, kwargs)
   3464           self._function_cache.primary[cache_key] = graph_function
   3465 

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   3306             arg_names=arg_names,
   3307             override_flat_arg_shapes=override_flat_arg_shapes,
-> 3308             capture_by_value=self._capture_by_value),
   3309         self._function_attributes,
   3310         function_spec=self.function_spec,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes, acd_record_initial_resource_uses)
   1005         _, original_func = tf_decorator.unwrap(python_func)
   1006 
-> 1007       func_outputs = python_func(*func_args, **func_kwargs)
   1008 
   1009       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/opt/conda/lib/python3.7/site-packages/tensorflow/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    666         # the function a weak reference to itself to avoid a reference cycle.
    667         with OptionalXlaContext(compile_with_xla):
--> 668           out = weak_wrapped_fn().__wrapped__(*args, **kwds)
    669         return out
    670 

/opt/conda/lib/python3.7/site-packages/elegy/model/utils.py in <lambda>(inputs)
     90         )
     91         tf_fun = tf.function(
---> 92             lambda inputs: tf_fn(param_vars, inputs),
     93             autograph=False,
     94             experimental_compile=compile_model,

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in converted_fun(*args, **kwargs)
    441       else:
    442         out_with_avals = _interpret_fun(flat_fun, args_flat, args_avals_flat,
--> 443                                         name_stack, fresh_constant_cache=True)
    444         outs, out_avals = util.unzip2(out_with_avals)
    445         message = ("The jax2tf-converted function does not support gradients. "

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _interpret_fun(fun, in_vals, in_avals, extra_name_stack, fresh_constant_cache)
    511         out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
    512             _call_wrapped_with_new_constant_cache(fun, in_vals,
--> 513                                                   fresh_constant_cache=fresh_constant_cache)
    514 
    515       del main

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _call_wrapped_with_new_constant_cache(fun, in_vals, fresh_constant_cache)
    530 
    531     out_vals: Sequence[Tuple[TfVal, core.ShapedArray]] = \
--> 532         fun.call_wrapped(*in_vals)
    533   finally:
    534     if prev_constant_cache is not None and not fresh_constant_cache:

/opt/conda/lib/python3.7/site-packages/jax/linear_util.py in call_wrapped(self, *args, **kwargs)
    166 
    167     try:
--> 168       ans = self.f(*args, **dict(self.params, **kwargs))
    169     except:
    170       # Some transformations yield from inside context managers, so we have to

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in fun_no_kwargs(*args_and_kwargs)
    284       kwargs = {kw: args_and_kwargs[nr_positional_args + i]
    285                 for i, kw in enumerate(kw_names)}
--> 286       return fun(*args, **kwargs)
    287 
    288     def check_arg(a):

/opt/conda/lib/python3.7/site-packages/elegy/model/model_core.py in jax_fn(flat_states, inputs)
    755 
    756             y_pred, _ = model.pred_step(
--> 757                 inputs=inputs,
    758             )
    759 

/opt/conda/lib/python3.7/site-packages/elegy/model/model.py in pred_step(self, inputs)
    196         inputs_obj = tx.Inputs.from_value(inputs)
    197 
--> 198         preds = model.module(*inputs_obj.args, **inputs_obj.kwargs)
    199 
    200         return preds, model

/opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs)
    114             @functools.wraps(cls.__call__)
    115             def new_call(self: Module, *args, **kwargs):
--> 116                 outputs = orig_call(self, *args, **kwargs)
    117 
    118                 if (

/opt/conda/lib/python3.7/site-packages/treeo/api.py in wrapper(tree, *args, **kwargs)
    518     def wrapper(tree, *args, **kwargs):
    519         with tree_m._COMPACT_CONTEXT.compact(f, tree):
--> 520             return f(tree, *args, **kwargs)
    521 
    522     wrapper._treeo_compact = True

/tmp/ipykernel_33/786988758.py in __call__(self, x)
     14         x = eg.nn.Flatten()(x)
     15         # first layers
---> 16         x = eg.nn.Linear(self.n1)(x)
     17         x = jax.nn.relu(x)
     18         # first layers

/opt/conda/lib/python3.7/site-packages/treex/module.py in new_call(self, *args, **kwargs)
    114             @functools.wraps(cls.__call__)
    115             def new_call(self: Module, *args, **kwargs):
--> 116                 outputs = orig_call(self, *args, **kwargs)
    117 
    118                 if (

/opt/conda/lib/python3.7/site-packages/treex/nn/linear.py in __call__(self, x)
    119             params["bias"] = self.bias
    120 
--> 121         output = self.module.apply({"params": params}, x)
    122         return tp.cast(jnp.ndarray, output)

    [... skipping hidden 7 frame]

/opt/conda/lib/python3.7/site-packages/flax/linen/linear.py in __call__(self, inputs)
    188     y = lax.dot_general(inputs, kernel,
    189                         (((inputs.ndim - 1,), (0,)), ((), ())),
--> 190                         precision=self.precision)
    191     if self.use_bias:
    192       bias = self.param('bias', self.bias_init, (self.features,),

    [... skipping hidden 3 frame]

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in process_primitive(self, primitive, tracers, params)
    843           val_out = invoke_impl()
    844       else:
--> 845         val_out = invoke_impl()
    846 
    847     if primitive.multiple_results:

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in invoke_impl()
    809             _in_avals=args_avals,  # type: ignore
    810             _out_aval=out_aval,
--> 811             **params)
    812       else:
    813         return impl(*args_tf, **params)

/opt/conda/lib/python3.7/site-packages/jax/experimental/jax2tf/jax2tf.py in _dot_general(lhs, rhs, dimension_numbers, precision, preferred_element_type, _in_avals, _out_aval)
   1595       precision_config_proto,
   1596       preferred_element_type=preferred_element_type,
-> 1597       use_v2=True)
   1598   if _WRAP_JAX_JIT_WITH_TF_FUNCTION:
   1599     res = tf.stop_gradient(res)  # See #7839

TypeError: dot_general() got an unexpected keyword argument 'use_v2'
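
The last frame shows jax2tf passing use_v2=True to TensorFlow's XLA dot_general wrapper, a keyword the TensorFlow build preinstalled on the Kaggle image apparently does not accept yet. My working assumption is a jax/tensorflow version mismatch; a quick check like the sketch below (purely diagnostic, not part of the original notebook) prints the versions involved:

```python
# Diagnostic only: compare the installed versions. If jax (and its jax2tf
# converter) is newer than tensorflow, aligning the two (e.g. upgrading
# tensorflow, or pinning jax to an older release) usually clears this kind
# of "unexpected keyword argument" TypeError.
import jax
import tensorflow as tf

print("jax:", jax.__version__)
print("tensorflow:", tf.__version__)
```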
