Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extend unit tests #1967

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 138 additions & 0 deletions onnxscript/_internal/param_manipulation_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,144 @@
allow_extra_kwargs=allow_extra_kwargs,
)

def test_tag_arguments_with_extra_kwargs_not_allowed(self):
param_schemas = (
values.ParamSchema(name="a", type=INT64, is_input=True),
values.ParamSchema(name="b", type=int, is_input=False),
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
args = (TEST_INPUT,)
kwargs = {"b": 42, "extra": 100}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
with self.assertRaises(TypeError):
_, _ = param_manipulation.tag_arguments_with_param_schemas(
param_schemas, args, kwargs, allow_extra_kwargs=False
)


def test_turn_to_kwargs_with_variadic_inputs(self):
param_schemas = (
values.ParamSchema(name="a", type=INT64, is_input=True, is_variadic_input=True),
values.ParamSchema(name="b", type=int, is_input=False),
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
inputs = [TEST_INPUT, TEST_INPUT, TEST_INPUT]
attributes = {"b": 42}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
expected_attributes = {
"a": [TEST_INPUT, TEST_INPUT, TEST_INPUT],
"b": 42,
}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
result = param_manipulation.turn_to_kwargs_to_avoid_ordering(
param_schemas, inputs, attributes
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
self.assertEqual(result, expected_attributes)


def test_tag_arguments_with_variadic_inputs(self):
param_schemas = (
values.ParamSchema(name="a", type=INT64, is_input=True, is_variadic_input=True),
values.ParamSchema(name="b", type=int, is_input=False),
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
args = (TEST_INPUT, TEST_INPUT, TEST_INPUT)
kwargs = {"b": 42}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
expected_tagged_args = [(TEST_INPUT, param_schemas[0]), (TEST_INPUT, param_schemas[0]), (TEST_INPUT, param_schemas[0])]
expected_tagged_kwargs = {"b": (42, param_schemas[1])}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_param_schemas(
param_schemas, args, kwargs
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
self.assertEqual(tagged_args, expected_tagged_args)
self.assertEqual(tagged_kwargs, expected_tagged_kwargs)


def test_turn_to_kwargs_to_avoid_ordering(self):
param_schemas = (
values.ParamSchema(name="a", type=INT64, is_input=True),
values.ParamSchema(name="b", type=int, is_input=True),
values.ParamSchema(name="c", type=float, is_input=False, default=0.0),
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
inputs = [TEST_INPUT, 42]
attributes = {"c": 0.0}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
expected_attributes = {
"a": TEST_INPUT,
"b": 42,
"c": 0.0,
}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
result = param_manipulation.turn_to_kwargs_to_avoid_ordering(
param_schemas, inputs, attributes
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
self.assertEqual(result, expected_attributes)


def test_tag_arguments_with_param_schemas(self):
param_schemas = (
values.ParamSchema(name="a", type=INT64, is_input=True),
values.ParamSchema(name="b", type=int, is_input=False, default=100),
values.ParamSchema(name="c", type=float, is_input=False, default=0.0),
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
args = (TEST_INPUT,)
kwargs = {"b": 42}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
expected_tagged_args = [(TEST_INPUT, param_schemas[0])]
expected_tagged_kwargs = {
"b": (42, param_schemas[1]),
"c": (0.0, param_schemas[2]),
}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
tagged_args, tagged_kwargs = param_manipulation.tag_arguments_with_param_schemas(
param_schemas, args, kwargs
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
self.assertEqual(tagged_args, expected_tagged_args)
self.assertEqual(tagged_kwargs, expected_tagged_kwargs)


def test_required_input_not_provided(self):
param_schemas = (
values.ParamSchema(name="a", type=INT64, is_input=True, required=True),
values.ParamSchema(name="b", type=int, is_input=False, default=100),
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
args = ()
kwargs = {"b": 42}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
with self.assertRaises(TypeError):
_, _ = param_manipulation.tag_arguments_with_param_schemas(
param_schemas, args, kwargs
)


def test_variadic_inputs(self):
param_schemas = (
values.ParamSchema(name="a", type=INT64, is_input=True, is_variadic_input=True),
values.ParamSchema(name="b", type=int, is_input=False),
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
args = (TEST_INPUT, TEST_INPUT, TEST_INPUT)
kwargs = {"b": 42}

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
expected_inputs = [TEST_INPUT, TEST_INPUT, TEST_INPUT]
expected_attributes = collections.OrderedDict([("b", 42)])

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
inputs, attributes = param_manipulation.separate_input_attributes_from_arguments(
param_schemas, args, kwargs
)

Check warning

Code scanning / lintrunner

RUFF/W293 Warning

Check warning

Code scanning / lintrunner

EDITORCONFIG-CHECKER/editorconfig Warning

Trailing whitespace
self.assertEqual(inputs, expected_inputs)
self.assertEqual(attributes, expected_attributes)



if __name__ == "__main__":
unittest.main()
84 changes: 84 additions & 0 deletions onnxscript/_legacy_ir/visitor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,90 @@ def test_function_optional_input_is_recorded_by_shape_env(self):
model_visitor.function_shape_env.lookup(model.functions[0], "optional_z")
)

def test_proto_visitor_enter_exit_function_scope(self):
function_proto = onnx.FunctionProto()
visitor_instance = visitor.ProtoVisitor()
visitor_instance.enter_function_scope(function_proto)
self.assertIsNotNone(visitor_instance.scopes.current_scope().current_function_scope())
visitor_instance.exit_function_scope(function_proto)
self.assertIsNone(visitor_instance.scopes.current_scope().current_function_scope())


def test_proto_visitor_missing_input_types(self):
node_proto = onnx.helper.make_node(
'Add',
inputs=['A', 'B'],
outputs=['C']
)
visitor_instance = visitor.ProtoVisitor(do_shape_inference=True)
visitor_instance.scopes.enter_graph_scope(onnx.GraphProto())
visitor_instance.bind('A', visitor.ir.Value(name='A', type=onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [1])))
visitor_instance.bind('B', visitor.ir.Value(name='B', type=onnx.helper.make_tensor_type_proto(onnx.TensorProto.FLOAT, [1])))
visitor_instance.process_node(node_proto)
output_value = visitor_instance.lookup('C')
self.assertIsNone(output_value.type)


def test_save_to_value_info_with_overload(self):
shape_env = visitor.FunctionShapeEnv()
value_info = onnx.helper.make_tensor_value_info('custom::function::overload/x', onnx.TensorProto.FLOAT, [1])
with self.assertRaises(NotImplementedError):
shape_env.save_to_value_info(visitor.ir.Value(name='x', type=value_info.type), 'custom', 'function', 'overload')


def test_save_to_model_proto_with_function_id_and_value_info(self):
model_proto = onnx.ModelProto()
model_proto.graph.value_info.extend([
onnx.helper.make_tensor_value_info('custom::function/x', onnx.TensorProto.FLOAT, [1])
])
shape_env = visitor.FunctionShapeEnv()
shape_env.load_from_model_proto(model_proto)
shape_env.save_to_model_proto(model_proto)
self.assertEqual(len(model_proto.graph.value_info), 2)


def test_subscope_bind_and_lookup_ref_attribute(self):
graph_proto = onnx.GraphProto()
subscope = visitor.SubScope(graph_proto)
attr_proto = onnx.AttributeProto()
attr_proto.name = "attr1"
subscope.bind_ref_attribute("attr1", attr_proto)
self.assertEqual(subscope.lookup_ref_attribute("attr1"), attr_proto)


def test_scope_bind_empty_name(self):
scope = visitor.Scope()
scope.enter_sub_scope(onnx.GraphProto())
with self.assertRaises(ValueError):
scope.bind("", visitor.ir.Value(name="value"))


def test_load_from_value_info_with_function_id(self):
value_info = onnx.helper.make_tensor_value_info('custom::function/x', onnx.TensorProto.FLOAT, [1])
shape_env = visitor.FunctionShapeEnv()
shape_env.load_from_value_info(value_info)
self.assertEqual(len(shape_env._function_values), 1)


def test_scope_bind_none_value(self):
scope = visitor.Scope()
scope.enter_sub_scope(onnx.GraphProto())
with self.assertRaises(ValueError):
scope.bind("test_name", None)


def test_load_from_value_info_with_none_function_id(self):
value_info = onnx.helper.make_tensor_value_info('x', onnx.TensorProto.FLOAT, [1])
shape_env = visitor.FunctionShapeEnv()
shape_env.load_from_value_info(value_info)
self.assertEqual(len(shape_env._function_values), 0)


def test_override_inferred_value_type_with_none_values(self):
result = visitor._override_inferred_value_type_with_symbolic_value_type(None, None)
self.assertIsNone(result)



if __name__ == "__main__":
unittest.main()
7 changes: 7 additions & 0 deletions onnxscript/backend/onnx_backend_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from onnxscript.backend import onnx_backend

import numpy as np

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note

third party import "import numpy as np" should be placed before "from onnxscript.backend import onnx_backend" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order

def load_function(obj):
return ort.InferenceSession(obj.SerializeToString(), providers=("CPUExecutionProvider",))
Expand Down Expand Up @@ -41,6 +42,12 @@
done += 1
self.assertEqual(done, 1)

def test_assert_almost_equal_string_with_floats(self):
expected = np.array([1.0, 2.0, 3.0])
value = np.array([1.0, 2.0, 3.0])
onnx_backend.assert_almost_equal_string(expected, value)



if __name__ == "__main__":
unittest.main(verbosity=2)
76 changes: 76 additions & 0 deletions onnxscript/evaluator_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from onnxscript.onnx_opset import opset17 as op
from onnxscript.onnx_types import FLOAT

from onnxscript import tensor
import onnx

Check notice

Code scanning / lintrunner

PYLINT/C0411 Note

third party import "import onnx" should be placed before "from onnxscript import evaluator, graph, script" (wrong-import-order)
See wrong-import-order. To disable, use # pylint: disable=wrong-import-order

class EvaluatorTest(unittest.TestCase):
def test_evaluator(self):
Expand Down Expand Up @@ -62,5 +64,79 @@
_ = test_function(x, unknown=42) # pylint: disable=unexpected-keyword-arg



def test_adapt_to_eager_mode_list_of_numpy_arrays(self):
inputs = [np.array([1, 2]), np.array([3, 4])]
expected = [tensor.Tensor(np.array([1, 2])), tensor.Tensor(np.array([3, 4]))]
result, has_array = evaluator._adapt_to_eager_mode(inputs)

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _adapt_to_eager_mode of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access
for res, exp in zip(result, expected):
np.testing.assert_array_equal(res.value, exp.value)
self.assertTrue(has_array)


def test_compute_num_outputs_scan(self):
schema = onnx.defs.get_schema("Scan", 9)
args = [np.array([1, 2, 3, 4])]
kwargs = {'body': onnx.helper.make_graph([], "body", [], [onnx.helper.make_tensor_value_info("output", onnx.TensorProto.FLOAT, [1])])}
expected_outputs = 1
result = evaluator.compute_num_outputs(schema, args, kwargs)
self.assertEqual(result, expected_outputs)


def test_compute_num_outputs_variable_outputs(self):
schema = onnx.defs.get_schema("Split", 13)
args = [np.array([1, 2, 3, 4]), np.array([2, 2])]
kwargs = {}
expected_outputs = 2
result = evaluator.compute_num_outputs(schema, args, kwargs)
self.assertEqual(result, expected_outputs)


def test_adapt_to_user_mode_single_numpy_array(self):
input_array = np.array([1, 2, 3])
expected = np.array([1, 2, 3])
result = evaluator._adapt_to_user_mode(input_array)

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _adapt_to_user_mode of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access
np.testing.assert_array_equal(result, expected)


def test_adapt_to_eager_mode_single_none(self):
input_none = None
expected = None
result, has_array = evaluator._adapt_to_eager_mode(input_none)

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _adapt_to_eager_mode of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access
self.assertEqual(result, expected)
self.assertFalse(has_array)


def test_adapt_to_eager_mode_single_scalar(self):
input_scalar = 5
expected = tensor.Tensor(np.array(input_scalar, dtype=np.int64))
result, has_array = evaluator._adapt_to_eager_mode(input_scalar)

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _adapt_to_eager_mode of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access
self.assertEqual(result, expected)
self.assertFalse(has_array)


def test_adapt_to_user_mode_tuple_of_tensors(self):
input_tensors = (tensor.Tensor(np.array([1, 2, 3])), tensor.Tensor(np.array([4, 5, 6])))
expected = (np.array([1, 2, 3]), np.array([4, 5, 6]))
result = evaluator._adapt_to_user_mode(input_tensors)

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _adapt_to_user_mode of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access
np.testing.assert_array_equal(result[0], expected[0])
np.testing.assert_array_equal(result[1], expected[1])


def test_unwrap_tensors_in_kwargs_mixed(self):
kwargs = {'a': tensor.Tensor(np.array([1, 2, 3])), 'b': np.array([4, 5, 6])}
expected = {'a': np.array([1, 2, 3]), 'b': np.array([4, 5, 6])}
result = evaluator._unwrap_tensors_in_kwargs(kwargs)

Check warning

Code scanning / lintrunner

PYLINT/W0212 Warning

Access to a protected member _unwrap_tensors_in_kwargs of a client class (protected-access)
See protected-access. To disable, use # pylint: disable=protected-access
np.testing.assert_array_equal(result['a'], expected['a'])
np.testing.assert_array_equal(result['b'], expected['b'])


def test_compute_num_outputs_split_no_num_outputs(self):
schema = onnx.defs.get_schema("Split", 13)
args = [np.array([1, 2, 3, 4])]
kwargs = {}
with self.assertRaises(evaluator.EagerModeError):
evaluator.compute_num_outputs(schema, args, kwargs)

if __name__ == "__main__":
unittest.main()
Loading
Loading