From 06e8cbcf5570253cb059d211b9697c312aef88f9 Mon Sep 17 00:00:00 2001 From: Hari Date: Fri, 5 Jul 2024 13:53:28 +0530 Subject: [PATCH] feat: support some metrics being 'None' without stopping training (#169) Some metrics may not be available at the time of rule evaluation. Add some more unit tests for the same conditions. Signed-off-by: Harikrishnan Balagopal Signed-off-by: Mehant Kammakomati --- tests/data/trainercontroller/__init__.py | 3 + .../loss_unavailable_metric.yaml | 10 + tests/trainercontroller/__init__.py | 13 ++ tests/trainercontroller/custom_metric.py | 10 +- tests/trainercontroller/custom_operation.py | 10 +- .../custom_operation_invalid_action.py | 10 +- .../test_tuning_trainercontroller.py | 52 +++-- tests/utils/test_evaluator.py | 191 +++++++++++++++++- tuning/trainercontroller/callback.py | 10 +- tuning/utils/evaluator.py | 149 ++++++++++++-- 10 files changed, 401 insertions(+), 57 deletions(-) create mode 100644 tests/data/trainercontroller/loss_unavailable_metric.yaml create mode 100644 tests/trainercontroller/__init__.py diff --git a/tests/data/trainercontroller/__init__.py b/tests/data/trainercontroller/__init__.py index a18d746d2..1dd93b6db 100644 --- a/tests/data/trainercontroller/__init__.py +++ b/tests/data/trainercontroller/__init__.py @@ -50,6 +50,9 @@ TRAINER_CONFIG_TEST_INVALID_METRIC_YAML = os.path.join( _DATA_DIR, "loss_invalid_metric.yaml" ) +TRAINER_CONFIG_TEST_UNAVAILABLE_METRIC_YAML = os.path.join( + _DATA_DIR, "loss_unavailable_metric.yaml" +) TRAINER_CONFIG_TEST_CUSTOM_METRIC_YAML = os.path.join( _DATA_DIR, "loss_custom_metric.yaml" ) diff --git a/tests/data/trainercontroller/loss_unavailable_metric.yaml b/tests/data/trainercontroller/loss_unavailable_metric.yaml new file mode 100644 index 000000000..d50f9ea9b --- /dev/null +++ b/tests/data/trainercontroller/loss_unavailable_metric.yaml @@ -0,0 +1,10 @@ +controller-metrics: + - name: loss + class: Loss +controllers: + - name: loss-controller-unavailable-metric + triggers: + - on_step_end + rule: loss < 1.0 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/trainercontroller/__init__.py b/tests/trainercontroller/__init__.py new file mode 100644 index 000000000..38a9531ef --- /dev/null +++ b/tests/trainercontroller/__init__.py @@ -0,0 +1,13 @@ +# Copyright The FMS HF Tuning Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/trainercontroller/custom_metric.py b/tests/trainercontroller/custom_metric.py index 83b6acc53..5fcc439f1 100644 --- a/tests/trainercontroller/custom_metric.py +++ b/tests/trainercontroller/custom_metric.py @@ -16,12 +16,10 @@ # https://spdx.dev/learn/handling-license-info/ # Standard -from dataclasses import dataclass from typing import Any # Third Party from transformers import TrainerState -import pytest # Local from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler @@ -31,7 +29,8 @@ class CustomMetric(MetricHandler): """Implements a custom metric for testing""" def __init__(self, **kwargs): - """Initializes the metric handler, by registering the event list and arguments with base handler. + """Initializes the metric handler, + by registering the event list and arguments with base handler. Args: kwargs: List of arguments (key, value)-pairs @@ -39,14 +38,15 @@ def __init__(self, **kwargs): super().__init__(events=["on_log"], **kwargs) def validate(self) -> bool: - """Validate the training arguments (e.g logging_steps) are compatible with the computation of this metric. + """Validate the training arguments (e.g logging_steps) + are compatible with the computation of this metric. Returns: bool """ return True - def compute(self, state: TrainerState = None, **kwargs) -> Any: + def compute(self, _: TrainerState = None, **__) -> Any: """Just returns True (for testing purposes only). Args: diff --git a/tests/trainercontroller/custom_operation.py b/tests/trainercontroller/custom_operation.py index b09ff91de..2c402fa96 100644 --- a/tests/trainercontroller/custom_operation.py +++ b/tests/trainercontroller/custom_operation.py @@ -15,13 +15,9 @@ # SPDX-License-Identifier: Apache-2.0 # https://spdx.dev/learn/handling-license-info/ -# Standard -from dataclasses import dataclass -from typing import Any # Third Party -from transformers import TrainerControl, TrainerState -import pytest +from transformers import TrainerControl # Local from tuning.trainercontroller.operations import Operation @@ -30,14 +26,14 @@ class CustomOperation(Operation): """Implements a custom operation for testing""" - def __init__(self, **kwargs): + def __init__(self, **_): """Initializes the custom operation class. Args: kwargs: List of arguments (key, value)-pairs """ super().__init__() - def should_perform_action_xyz(self, control: TrainerControl, **kwargs): + def should_perform_action_xyz(self, control: TrainerControl, **_): """This method performs a set training stop flag action. Args: diff --git a/tests/trainercontroller/custom_operation_invalid_action.py b/tests/trainercontroller/custom_operation_invalid_action.py index 29b447bef..5c04199d3 100644 --- a/tests/trainercontroller/custom_operation_invalid_action.py +++ b/tests/trainercontroller/custom_operation_invalid_action.py @@ -15,13 +15,9 @@ # SPDX-License-Identifier: Apache-2.0 # https://spdx.dev/learn/handling-license-info/ -# Standard -from dataclasses import dataclass -from typing import Any # Third Party -from transformers import TrainerControl, TrainerState -import pytest +from transformers import TrainerControl # Local from tuning.trainercontroller.operations import Operation @@ -30,14 +26,14 @@ class CustomOperationInvalidAction(Operation): """Implements a custom operation for testing""" - def __init__(self, **kwargs): + def __init__(self, **_): """Initializes the custom operation class. Args: kwargs: List of arguments (key, value)-pairs """ super().__init__() - def should_(self, control: TrainerControl, **kwargs): + def should_(self, control: TrainerControl, **_): """This method defines an action within an invalid name. Args: diff --git a/tests/trainercontroller/test_tuning_trainercontroller.py b/tests/trainercontroller/test_tuning_trainercontroller.py index c572a9c3f..6788a8c61 100644 --- a/tests/trainercontroller/test_tuning_trainercontroller.py +++ b/tests/trainercontroller/test_tuning_trainercontroller.py @@ -46,7 +46,8 @@ class InputData: def _setup_data() -> InputData: """ - Sets up the test data for the test cases. This includes the logs, arguments for training and state + Sets up the test data for the test cases. + This includes the logs, arguments for training and state of the training. Returns: @@ -85,7 +86,7 @@ def test_loss_on_threshold(): tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) # Trigger rule and test the condition tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) - assert control.should_training_stop == True + assert control.should_training_stop is True def test_loss_on_threshold_with_trainer_state(): @@ -117,7 +118,7 @@ def test_exposed_metrics(): tc_callback.on_evaluate( args=test_data.args, state=test_data.state, control=control, metrics=metrics ) - assert control.should_training_stop == True + assert control.should_training_stop is True def test_incorrect_source_event_exposed_metrics(): @@ -143,7 +144,7 @@ def test_incorrect_source_event_exposed_metrics(): str(exception_handler.value).strip("'") == "Specified source event [on_incorrect_event] is invalid for EvalMetrics" ) - assert control.should_training_stop == True + assert control.should_training_stop is True def test_custom_metric_handler(): @@ -160,7 +161,7 @@ def test_custom_metric_handler(): tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) # Trigger rule and test the condition tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) - assert control.should_training_stop == True + assert control.should_training_stop is True def test_custom_operation_handler(): @@ -177,7 +178,7 @@ def test_custom_operation_handler(): tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) # Trigger rule and test the condition tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) - assert control.should_training_stop == True + assert control.should_training_stop is True def test_custom_operation_invalid_action_handler(): @@ -197,9 +198,9 @@ def test_custom_operation_invalid_action_handler(): ) # Trigger rule and test the condition tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) - assert ( - str(exception_handler.value).strip("'") - == "Invalid operation customoperation.should_ for control loss-controller-custom-operation-invalid-action" + assert str(exception_handler.value).strip("'") == ( + "Invalid operation customoperation.should_ for control" + + " loss-controller-custom-operation-invalid-action" ) @@ -282,9 +283,9 @@ def test_invalid_trigger(): ) # Trigger rule and test the condition tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) - assert ( - str(exception_handler.value).strip("'") - == "Controller loss-controller-invalid-trigger has an invalid event (log_it_all_incorrect_trigger_name)" + assert str(exception_handler.value).strip("'") == ( + "Controller loss-controller-invalid-trigger has" + + " an invalid event (log_it_all_incorrect_trigger_name)" ) @@ -304,9 +305,9 @@ def test_invalid_operation(): ) # Trigger rule and test the condition tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) - assert ( - str(exception_handler.value).strip("'") - == "Invalid operation missingop.should_training_stop for control loss-controller-invalid-operation" + assert str(exception_handler.value).strip("'") == ( + "Invalid operation missingop.should_training_stop" + + " for control loss-controller-invalid-operation" ) @@ -326,9 +327,9 @@ def test_invalid_operation_action(): ) # Trigger rule and test the condition tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) - assert ( - str(exception_handler.value).strip("'") - == "Invalid operation hfcontrols.missingaction for control loss-controller-invalid-operation-action" + assert str(exception_handler.value).strip("'") == ( + "Invalid operation hfcontrols.missingaction" + + " for control loss-controller-invalid-operation-action" ) @@ -352,3 +353,18 @@ def test_invalid_metric(): str(exception_handler.value).strip("'") == "Undefined metric handler MissingMetricClass" ) + + +def test_unavailable_metric(): + """Tests the invalid metric scenario in the controller. Uses: + `examples/trainer-controller-configs/loss_invalid_metric.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_UNAVAILABLE_METRIC_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + # Trigger rule and test the condition + tc_callback.on_step_end(args=test_data.args, state=test_data.state, control=control) diff --git a/tests/utils/test_evaluator.py b/tests/utils/test_evaluator.py index 9bb2e4fad..87fd65aec 100644 --- a/tests/utils/test_evaluator.py +++ b/tests/utils/test_evaluator.py @@ -23,11 +23,15 @@ import pytest # Local -from tuning.utils.evaluator import get_evaluator +from tuning.utils.evaluator import RuleEvaluator def test_mailicious_inputs_to_eval(): - """Tests the malicious rules""" + """Tests the malicious rules + + Each test case has the format: + (validation_error: str, expected_rule_is_true: bool, rule: str) + """ rules: list[Tuple[str, bool, str]] = [ # Valid rules ("", False, "flags['is_training'] == False"), @@ -46,12 +50,17 @@ def test_mailicious_inputs_to_eval(): ("", False, "(loss*loss)*loss < 1.0"), ("", True, "int(''.join(['3', '4'])) < loss"), ("", True, "loss < 9**9"), - ("", False, "loss < sqrt(xs[0]*xs[0] + xs[1]*xs[1])"), + ("", False, "loss < math_sqrt(xs[0]*xs[0] + xs[1]*xs[1])"), ("", True, "len(xs) > 2"), ("", True, "loss < abs(-100)"), ("", True, "loss == flags.aaa.bbb[0].ccc"), ("", True, "array3d[0][1][1] == 4"), ("", True, "numpyarray[0][1][1] == 4"), + ("", True, "unavailablemetric == None"), + ("", False, "unavailablemetric != None"), + ("", False, "loss < 2.0 if unavailablemetric == None else loss > 0.0"), + ("", True, "loss < 2.0 if unavailablemetric != None else loss > 0.0"), + ("", True, "False if loss == None else loss > 0.0"), ( "", True, @@ -127,6 +136,177 @@ def test_mailicious_inputs_to_eval(): True, "mymetric2(loss) > loss", ), + ( + "'<' not supported between instances of 'NoneType' and 'float'", + True, + "None < 2.0", + ), + ( + "'nonexistentmetric' is not defined for expression 'nonexistentmetric < 3.0'", + True, + "nonexistentmetric < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "unavailablemetric < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "unavailablemetric <= 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "unavailablemetric == 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "unavailablemetric != 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "unavailablemetric > 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "unavailablemetric >= 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric + unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric - unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric * unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric / unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric // unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + r"(unavailablemetric % unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric ** unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric << unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric >> unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric & unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric ^ unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(unavailablemetric | unavailablemetric) < 2.0", + ), + # https://docs.python.org/3/reference/datamodel.html#object.__radd__ + ( + "unsupported operand type(s) for +: 'NoneType' and 'UnavailableMetric'", + True, + "(None + unavailablemetric) < 2.0", + ), + ( + "unsupported operand type(s) for -: 'NoneType' and 'UnavailableMetric'", + True, + "(None - unavailablemetric) < 2.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "xs[unavailablemetric] < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "unavailablemetric[0] < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "int(unavailablemetric) < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "float(unavailablemetric) < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "-unavailablemetric < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "+unavailablemetric < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "abs(unavailablemetric) < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "(~unavailablemetric) < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "round(unavailablemetric) < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "math_trunc(unavailablemetric) < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "math_floor(unavailablemetric) < 3.0", + ), + ( + "The metric 'unavailablemetric' is not available", + True, + "math_ceil(unavailablemetric) < 3.0", + ), ] metrics = { "loss": 42.0, @@ -143,9 +323,10 @@ def test_mailicious_inputs_to_eval(): ], ], "numpyarray": (np.arange(8).reshape((2, 2, 2)) + 1), + "unavailablemetric": None, } - evaluator = get_evaluator(metrics=metrics) + evaluator = RuleEvaluator(metrics=metrics) for validation_error, expected_rule_is_true, rule in rules: rule_parsed = evaluator.parse(expr=rule) @@ -156,7 +337,7 @@ def test_mailicious_inputs_to_eval(): ) assert ( actual_rule_is_true == expected_rule_is_true - ), "failed to execute the rule" + ), f"failed to execute the rule: '{rule}'" else: with pytest.raises(Exception) as exception_handler: evaluator.eval( diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index d30821f19..b14470332 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -44,7 +44,7 @@ from tuning.trainercontroller.operations import ( operation_handlers as default_operation_handlers, ) -from tuning.utils.evaluator import get_evaluator +from tuning.utils.evaluator import MetricUnavailableError, RuleEvaluator logger = logging.get_logger(__name__) @@ -217,7 +217,7 @@ def _take_control_actions(self, event_name: str, **kwargs): kwargs: List of arguments (key, value)-pairs. """ if event_name in self.control_actions_on_event: - evaluator = get_evaluator(metrics=self.metrics) + evaluator = RuleEvaluator(metrics=self.metrics) for control_action in self.control_actions_on_event[event_name]: rule_succeeded = False try: @@ -248,6 +248,9 @@ def _take_control_actions(self, event_name: str, **kwargs): raise NotImplementedError( "Rule failed because it uses some unsupported features" ) from ef + except MetricUnavailableError as em: + logger.warning("Ignoring the rule because %s", em) + continue if rule_succeeded: for operation_action in control_action.operation_actions: logger.info( @@ -324,6 +327,9 @@ def on_init_end( metric_handler = metric_handler_class( name=metric_name, **metric_args, **kwargs ) + # Initialize the metric with a None value so that + # the evaluator knows that the metric is unavailable. + self.metrics[metric_handler.get_name()] = None # Add metric instances to the events. for event_name in metric_handler.get_events(): if event_name in self.valid_events: diff --git a/tuning/utils/evaluator.py b/tuning/utils/evaluator.py index 42095e70c..ec0c306b8 100644 --- a/tuning/utils/evaluator.py +++ b/tuning/utils/evaluator.py @@ -1,20 +1,143 @@ # Standard -from math import sqrt +import math # Third Party from simpleeval import DEFAULT_FUNCTIONS, DEFAULT_NAMES, EvalWithCompoundTypes -def get_evaluator(metrics: dict) -> EvalWithCompoundTypes: +class MetricUnavailableError(Exception): + def __init__(self, name): + super().__init__(f"The metric '{name}' is not available") + self.name = name + + +class UnavailableMetric: + def __init__(self, name: str) -> None: + self.err = MetricUnavailableError(name=name) + + def raise_error(self): + raise self.err + + # https://docs.python.org/3/reference/datamodel.html#object.__lt__ + def __lt__(self, _): + raise self.err + + def __le__(self, _): + raise self.err + + def __eq__(self, other): + if other is None: + return True + raise self.err + + # Use the default implementation + # def __ne__(self, _): + # raise self.err + + def __gt__(self, _): + raise self.err + + def __ge__(self, _): + raise self.err + + def __getitem__(self, _): + raise self.err + + # https://docs.python.org/3/reference/datamodel.html#object.__add__ + def __add__(self, _): + raise self.err + + def __sub__(self, _): + raise self.err + + def __mul__(self, _): + raise self.err + + def __truediv__(self, _): + raise self.err + + def __floordiv__(self, _): + raise self.err + + def __mod__(self, _): + raise self.err + + def __and__(self, _): + raise self.err + + def __xor__(self, _): + raise self.err + + def __or__(self, _): + raise self.err + + # https://docs.python.org/3/reference/datamodel.html#object.__neg__ + def __neg__(self): + raise self.err + + def __pos__(self): + raise self.err + + def __abs__(self): + raise self.err + + def __invert__(self): + raise self.err + + # https://docs.python.org/3/reference/datamodel.html#object.__int__ + def __int__(self): + raise self.err + + def __float__(self): + raise self.err + + # https://docs.python.org/3/reference/datamodel.html#object.__round__ + def __round__(self, _=None): + raise self.err + + def __trunc__(self): + raise self.err + + def __floor__(self): + raise self.err + + def __ceil__(self): + raise self.err + + +class RuleEvaluator(EvalWithCompoundTypes): """Returns an evaluator that can be used to evaluate simple Python expressions.""" - all_names = { - **metrics, - **DEFAULT_NAMES.copy(), - } - all_funcs = { - "abs": abs, - "len": len, - "sqrt": sqrt, - **DEFAULT_FUNCTIONS.copy(), - } - return EvalWithCompoundTypes(functions=all_funcs, names=all_names) + + def __init__(self, metrics: dict): + all_names = { + **metrics, + **DEFAULT_NAMES.copy(), + } + all_funcs = { + "abs": abs, + "len": len, + "round": round, + "math_trunc": math.trunc, + "math_floor": math.floor, + "math_ceil": math.ceil, + "math_sqrt": math.sqrt, + **DEFAULT_FUNCTIONS.copy(), + } + super().__init__(functions=all_funcs, names=all_names) + self.metrics = metrics + + def _eval_name(self, node): + name = node.id + if ( + isinstance(name, str) + and name in self.metrics + and self.metrics[name] is None + ): + return UnavailableMetric(name=name) + return super()._eval_name(node) + + def _eval_subscript(self, node): + key = self._eval(node.slice) + if isinstance(key, UnavailableMetric): + key.raise_error() + return super()._eval_subscript(node)