diff --git a/architecture_records/001-trainer-controller-framework.md b/architecture_records/001-trainer-controller-framework.md index 1bf79d67f..196f6ad21 100644 --- a/architecture_records/001-trainer-controller-framework.md +++ b/architecture_records/001-trainer-controller-framework.md @@ -98,7 +98,7 @@ controller-metrics: - name: loss class: Loss controllers: - - name: loss-controller + - name: loss_controller triggers: - on_log rule: "loss < 1.0" @@ -107,9 +107,8 @@ controllers: ``` We follow the below naming convention for the above trainer controller configuration: -1. `-` could be used in the case of key names, and name of the metric, operation or controller. This is usually to break multiple words of a name phrase. 1. Python convention for [class name](https://visualgit.readthedocs.io/en/latest/pages/naming_convention.html#classes). -1. `_` are used for events and control actions. +1. `_` should be used between words in keys, values, events and control actions. For defining custom handler classes, we have an interface defined as an abstract class as shown below, with two abstract methods, namely: `validate()` to define the validation conditions, and `compute()` to compute the metric. The `compute()` returns an `Any` type. While it could be any value, developers should keep in mind that it should be only key-value pairs that are used in the rule(s) defined in the configuration. diff --git a/examples/trainercontroller_configs/epoch-level-eval-loss-below-threshold.yaml b/examples/trainercontroller_configs/epoch-level-eval-loss-below-threshold.yaml new file mode 100644 index 000000000..d8e903294 --- /dev/null +++ b/examples/trainercontroller_configs/epoch-level-eval-loss-below-threshold.yaml @@ -0,0 +1,12 @@ +controller_metrics: + - name: trainer_state + class: TrainingState + - name: evalmetric + class: EvalMetrics +controllers: + - name: epoch_level_eval_loss_below_threshold + triggers: + - on_epoch_end + rule: evalmetric['eval_loss'] < 2.25 and trainer_state["epoch"] > 2 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/examples/trainercontroller_configs/epoch-level-eval-loss-patience.yaml b/examples/trainercontroller_configs/epoch-level-eval-loss-patience.yaml new file mode 100644 index 000000000..d86e96a6e --- /dev/null +++ b/examples/trainercontroller_configs/epoch-level-eval-loss-patience.yaml @@ -0,0 +1,14 @@ +controller_metrics: + - name: eval_loss_window + class: HistoryBasedMetric + arguments: + window_size: 2 +controllers: + - name: epoch_level_eval_loss_patience + triggers: + - on_epoch_end + rule: len(eval_loss_window["metrics"]) > 0 and eval_loss_window["metrics"]["eval_loss"][-1] > 2 + patience: + patience_threshold: 2 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/examples/trainercontroller_configs/epoch-level-eval-loss.yaml b/examples/trainercontroller_configs/epoch-level-eval-loss.yaml new file mode 100644 index 000000000..ac0f15280 --- /dev/null +++ b/examples/trainercontroller_configs/epoch-level-eval-loss.yaml @@ -0,0 +1,14 @@ +controller_metrics: + - name: trainer_state + class: TrainingState + - name: eval_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: epoch_level_eval_loss + triggers: + - on_epoch_end + rule: len(eval_loss_window["metrics"]) > 0 and eval_loss_window["metrics"]["eval_loss"][-1] > 2.2 and trainer_state["epoch"] > 3 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git 
a/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold.yaml b/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold.yaml new file mode 100644 index 000000000..a0ff37255 --- /dev/null +++ b/examples/trainercontroller_configs/epoch-level-training-loss-below-threshold.yaml @@ -0,0 +1,12 @@ +controller_metrics: + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: epoch_level_stop_on_training_loss_below_threshold + triggers: + - on_log + rule: len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"] and training_loss_window["training_loss"]["loss"][0] < 2.2 and training_loss_window["training_loss"]["epoch"][0] > 2 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/examples/trainercontroller_configs/epoch-level-training-loss.yaml b/examples/trainercontroller_configs/epoch-level-training-loss.yaml new file mode 100644 index 000000000..0b41f3f7b --- /dev/null +++ b/examples/trainercontroller_configs/epoch-level-training-loss.yaml @@ -0,0 +1,14 @@ +controller_metrics: + - name: trainer_state + class: TrainingState + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: epoch_level_training_loss + triggers: + - on_epoch_end + rule: training_loss_window["training_loss"]["loss"][-1] > 2 and trainer_state["epoch"] > 3 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/examples/trainercontroller_configs/loss.yaml b/examples/trainercontroller_configs/loss.yaml index dd272d21c..d7d0baa2b 100644 --- a/examples/trainercontroller_configs/loss.yaml +++ b/examples/trainercontroller_configs/loss.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss controllers: - - name: loss-controller + - name: loss_controller triggers: - on_log rule: loss < 1.0 diff --git a/examples/trainercontroller_configs/non-decreasing-training-loss.yaml b/examples/trainercontroller_configs/non-decreasing-training-loss.yaml new file mode 100644 index 000000000..db504dede --- /dev/null +++ b/examples/trainercontroller_configs/non-decreasing-training-loss.yaml @@ -0,0 +1,12 @@ +controller_metrics: + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 5 +controllers: + - name: stop_on_training_loss_not_decreasing + triggers: + - on_log + rule: training_loss_window["training_loss"]["loss"][0] < training_loss_window["training_loss"]["loss"][-1] and len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"] + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/examples/trainercontroller_configs/thresholded-training-loss.yaml b/examples/trainercontroller_configs/thresholded-training-loss.yaml new file mode 100644 index 000000000..0092c0057 --- /dev/null +++ b/examples/trainercontroller_configs/thresholded-training-loss.yaml @@ -0,0 +1,12 @@ +controller_metrics: + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: stop_on_training_loss_above_threshold + triggers: + - on_log + rule: training_loss_window["training_loss"]["loss"][-1] > 2.2 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/__init__.py b/tests/data/trainercontroller/__init__.py index 1dd93b6db..aaaeabe93 100644 ---
a/tests/data/trainercontroller/__init__.py +++ b/tests/data/trainercontroller/__init__.py @@ -62,3 +62,18 @@ TRAINER_CONFIG_TEST_CUSTOM_OPERATION_INVALID_ACTION_YAML = os.path.join( _DATA_DIR, "loss_custom_operation_invalid_action.yaml" ) +TRAINER_CONFIG_TEST_EPOCH_LEVEL_EVAL_LOSS_PATIENCE_YAML = os.path.join( + _DATA_DIR, "epoch-level-eval-loss-patience.yaml" +) +TRAINER_CONFIG_TEST_EPOCH_LEVEL_EVAL_LOSS_YAML = os.path.join( + _DATA_DIR, "epoch-level-eval-loss.yaml" +) +TRAINER_CONFIG_TEST_EPOCH_LEVEL_TRAINING_LOSS_YAML = os.path.join( + _DATA_DIR, "epoch-level-training-loss.yaml" +) +TRAINER_CONFIG_TEST_NON_DECREASING_TRAINING_LOSS_YAML = os.path.join( + _DATA_DIR, "non-decreasing-training-loss.yaml" +) +TRAINER_CONFIG_TEST_THRESHOLDED_TRAINING_LOSS_YAML = os.path.join( + _DATA_DIR, "thresholded-training-loss.yaml" +) diff --git a/tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml b/tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml new file mode 100644 index 000000000..c0d5a191a --- /dev/null +++ b/tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml @@ -0,0 +1,14 @@ +controller_metrics: + - name: eval_loss_window + class: HistoryBasedMetric + arguments: + window_size: 2 +controllers: + - name: epoch_level_eval_loss_patience + triggers: + - on_epoch_end + rule: len(eval_loss_window["metrics"]["eval_loss"]) > 0 and eval_loss_window["metrics"]["eval_loss"][-1] > 2 + patience: + patience_threshold: 2 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/epoch-level-eval-loss.yaml b/tests/data/trainercontroller/epoch-level-eval-loss.yaml new file mode 100644 index 000000000..58b54c274 --- /dev/null +++ b/tests/data/trainercontroller/epoch-level-eval-loss.yaml @@ -0,0 +1,14 @@ +controller_metrics: + - name: trainer_state + class: TrainingState + - name: eval_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: epoch_level_eval_loss + triggers: + - on_epoch_end + rule: len(eval_loss_window["metrics"]["eval_loss"]) > 0 and eval_loss_window["metrics"]["eval_loss"][-1] > 2 and trainer_state["epoch"] > 0.1 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/epoch-level-training-loss.yaml b/tests/data/trainercontroller/epoch-level-training-loss.yaml new file mode 100644 index 000000000..d4e56ec92 --- /dev/null +++ b/tests/data/trainercontroller/epoch-level-training-loss.yaml @@ -0,0 +1,14 @@ +controller_metrics: + - name: trainer_state + class: TrainingState + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: epoch_level_training_loss + triggers: + - on_epoch_end + rule: training_loss_window["training_loss"]["loss"][-1] < 1 and trainer_state["epoch"] >= 0.5 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/exposed_metrics.yaml b/tests/data/trainercontroller/exposed_metrics.yaml index 6fef43d68..45136e87b 100644 --- a/tests/data/trainercontroller/exposed_metrics.yaml +++ b/tests/data/trainercontroller/exposed_metrics.yaml @@ -1,10 +1,10 @@ -controller-metrics: +controller_metrics: - name: evalmetric class: EvalMetrics arguments: - source-event: on_evaluate + source_event: on_evaluate controllers: - - name: loss-controller + - name: loss_controller triggers: - on_evaluate rule: evalmetric['eval_loss'] < 2.5 diff --git 
a/tests/data/trainercontroller/incorrect_source_event_exposed_metrics.yaml b/tests/data/trainercontroller/incorrect_source_event_exposed_metrics.yaml index b507150d1..ea96fe4b6 100644 --- a/tests/data/trainercontroller/incorrect_source_event_exposed_metrics.yaml +++ b/tests/data/trainercontroller/incorrect_source_event_exposed_metrics.yaml @@ -1,10 +1,10 @@ -controller-metrics: +controller_metrics: - name: evalmetric class: EvalMetrics arguments: - source-event: on_incorrect_event + source_event: on_incorrect_event controllers: - - name: loss-controller + - name: loss_controller triggers: - on_evaluate rule: evalmetric['eval_loss'] < 2.5 diff --git a/tests/data/trainercontroller/loss_custom_metric.yaml b/tests/data/trainercontroller/loss_custom_metric.yaml index fece59d9a..7fc4c6583 100644 --- a/tests/data/trainercontroller/loss_custom_metric.yaml +++ b/tests/data/trainercontroller/loss_custom_metric.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: testflag class: CustomMetric controllers: - - name: loss-controller-custom-metric + - name: loss_controller_custom_metric triggers: - on_log rule: testflag == True diff --git a/tests/data/trainercontroller/loss_custom_operation.yaml b/tests/data/trainercontroller/loss_custom_operation.yaml index 73737f8fb..603459234 100644 --- a/tests/data/trainercontroller/loss_custom_operation.yaml +++ b/tests/data/trainercontroller/loss_custom_operation.yaml @@ -1,13 +1,13 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss operations: - - name: customoperation + - name: custom_operation class: CustomOperation controllers: - - name: loss-controller-custom-operation + - name: loss_controller_custom_operation triggers: - on_log rule: loss < 1.0 operations: - - customoperation.should_perform_action_xyz \ No newline at end of file + - custom_operation.should_perform_action_xyz \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml index 80c07f296..3dac47cb2 100644 --- a/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml +++ b/tests/data/trainercontroller/loss_custom_operation_invalid_action.yaml @@ -1,13 +1,13 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss operations: - - name: customoperation + - name: custom_operation class: CustomOperationInvalidAction controllers: - - name: loss-controller-custom-operation-invalid-action + - name: loss_controller_custom_operation_invalid_action triggers: - on_log rule: loss < 1.0 operations: - - customoperation.should_ \ No newline at end of file + - custom_operation.should_ \ No newline at end of file diff --git a/tests/data/trainercontroller/loss_invalid_metric.yaml b/tests/data/trainercontroller/loss_invalid_metric.yaml index f86de8f57..4d94878aa 100644 --- a/tests/data/trainercontroller/loss_invalid_metric.yaml +++ b/tests/data/trainercontroller/loss_invalid_metric.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: loss class: MissingMetricClass controllers: - - name: loss-controller-invalid-metric + - name: loss_controller_invalid_metric triggers: - on_log rule: loss < 1.0 diff --git a/tests/data/trainercontroller/loss_invalid_operation.yaml b/tests/data/trainercontroller/loss_invalid_operation.yaml index 65aaff263..f904e27d9 100644 --- a/tests/data/trainercontroller/loss_invalid_operation.yaml +++ b/tests/data/trainercontroller/loss_invalid_operation.yaml @@ -1,8 +1,8 @@ -controller-metrics:
+controller_metrics: - name: loss class: Loss controllers: - - name: loss-controller-invalid-operation + - name: loss_controller_invalid_operation triggers: - on_log rule: loss < 1.0 diff --git a/tests/data/trainercontroller/loss_invalid_operation_action.yaml b/tests/data/trainercontroller/loss_invalid_operation_action.yaml index 6f72b65ea..3015516ef 100644 --- a/tests/data/trainercontroller/loss_invalid_operation_action.yaml +++ b/tests/data/trainercontroller/loss_invalid_operation_action.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss controllers: - - name: loss-controller-invalid-operation-action + - name: loss_controller_invalid_operation_action triggers: - on_log rule: loss < 1.0 diff --git a/tests/data/trainercontroller/loss_invalid_trigger.yaml b/tests/data/trainercontroller/loss_invalid_trigger.yaml index 5e509cbb9..382ad7783 100644 --- a/tests/data/trainercontroller/loss_invalid_trigger.yaml +++ b/tests/data/trainercontroller/loss_invalid_trigger.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss controllers: - - name: loss-controller-invalid-trigger + - name: loss_controller_invalid_trigger triggers: - log_it_all_incorrect_trigger_name rule: loss < 1.0 diff --git a/tests/data/trainercontroller/loss_on_threshold.yaml b/tests/data/trainercontroller/loss_on_threshold.yaml index dd272d21c..d7d0baa2b 100644 --- a/tests/data/trainercontroller/loss_on_threshold.yaml +++ b/tests/data/trainercontroller/loss_on_threshold.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss controllers: - - name: loss-controller + - name: loss_controller triggers: - on_log rule: loss < 1.0 diff --git a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml index c40bb58b2..45e2a3eea 100644 --- a/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml +++ b/tests/data/trainercontroller/loss_on_threshold_with_trainer_state.yaml @@ -1,10 +1,10 @@ -controller-metrics: +controller_metrics: - name: state class: TrainingState - name: loss class: Loss controllers: - - name: loss-controller + - name: loss_controller triggers: - on_log rule: loss < 2 and state["epoch"] >= 0.5 diff --git a/tests/data/trainercontroller/loss_unavailable_metric.yaml b/tests/data/trainercontroller/loss_unavailable_metric.yaml index d50f9ea9b..055b93cf3 100644 --- a/tests/data/trainercontroller/loss_unavailable_metric.yaml +++ b/tests/data/trainercontroller/loss_unavailable_metric.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss controllers: - - name: loss-controller-unavailable-metric + - name: loss_controller_unavailable_metric triggers: - on_step_end rule: loss < 1.0 diff --git a/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml b/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml index a2bd9e303..01495f106 100644 --- a/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml +++ b/tests/data/trainercontroller/loss_with_invalid_type_rule.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss controllers: - - name: loss-controller-wrong-os-rule + - name: loss_controller_wrong_os_rule triggers: - on_log rule: "2+2" diff --git a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml index a466675f6..6d5c65328 100644 --- 
a/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml +++ b/tests/data/trainercontroller/loss_with_malicious_input_rule.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss controllers: - - name: loss-controller-wrong-input-rule + - name: loss_controller_wrong_input_rule triggers: - on_log rule: input('Please enter your password:') diff --git a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml index 3c32e61df..badcf940a 100644 --- a/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml +++ b/tests/data/trainercontroller/loss_with_malicious_os_rule.yaml @@ -1,8 +1,8 @@ -controller-metrics: +controller_metrics: - name: loss class: Loss controllers: - - name: loss-controller-wrong-os-rule + - name: loss_controller_wrong_os_rule triggers: - on_log rule: __import__('os').system('clear') diff --git a/tests/data/trainercontroller/non-decreasing-training-loss.yaml b/tests/data/trainercontroller/non-decreasing-training-loss.yaml new file mode 100644 index 000000000..1ccfbb455 --- /dev/null +++ b/tests/data/trainercontroller/non-decreasing-training-loss.yaml @@ -0,0 +1,12 @@ +controller_metrics: + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 2 +controllers: + - name: stop_on_training_loss_not_decreasing + triggers: + - on_log + rule: training_loss_window["training_loss"]["loss"][0] < training_loss_window["training_loss"]["loss"][-1] and len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"] + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/data/trainercontroller/thresholded-training-loss.yaml b/tests/data/trainercontroller/thresholded-training-loss.yaml new file mode 100644 index 000000000..2f29bcd94 --- /dev/null +++ b/tests/data/trainercontroller/thresholded-training-loss.yaml @@ -0,0 +1,12 @@ +controller_metrics: + - name: training_loss_window + class: HistoryBasedMetric + arguments: + window_size: 1 +controllers: + - name: stop_on_training_loss_below_threshold + triggers: + - on_log + rule: training_loss_window["training_loss"]["loss"][0] < 1 + operations: + - hfcontrols.should_training_stop \ No newline at end of file diff --git a/tests/trainercontroller/test_tuning_trainercontroller.py b/tests/trainercontroller/test_tuning_trainercontroller.py index 6788a8c61..7f98ace94 100644 --- a/tests/trainercontroller/test_tuning_trainercontroller.py +++ b/tests/trainercontroller/test_tuning_trainercontroller.py @@ -17,6 +17,7 @@ # Standard from dataclasses import dataclass +from typing import List # Third Party from simpleeval import FunctionNotDefined @@ -41,7 +42,8 @@ class InputData: """Stores the operation handler instance and corresponding action""" args: config.TrainingArguments - state: TrainerState + states: List[TrainerState] + metrics: List[dict] def _setup_data() -> InputData: @@ -61,15 +63,42 @@ def _setup_data() -> InputData: logging_strategy=IntervalStrategy.STEPS, logging_steps=1, ), - state=TrainerState( - log_history=[ - {"loss": 2.0, "epoch": 0.1}, - {"loss": 2.1, "epoch": 0.25}, - {"loss": 1.3, "epoch": 0.5}, - {"loss": 0.9, "epoch": 0.6}, - ], - epoch=0.6, - ), + states=[ + TrainerState( + log_history=[ + {"loss": 2.0, "epoch": 0.1}, + {"loss": 2.1, "epoch": 0.25}, + ], + epoch=0.6, + global_step=1, + ), + TrainerState( + log_history=[ + {"loss": 2.0, "epoch": 0.1}, + {"loss": 2.1, "epoch": 0.25}, + {"loss": 1.3, "epoch": 0.5}, + ], + 
epoch=1.0, + global_step=2, + ), + TrainerState( + log_history=[ + {"loss": 2.0, "epoch": 0.1}, + {"loss": 2.1, "epoch": 0.25}, + {"loss": 1.3, "epoch": 0.5}, + {"loss": 0.9, "epoch": 0.6}, + ], + epoch=1.6, + global_step=3, + ), + ], + metrics=[ + {"eval_loss": 2.2}, + {"eval_loss": 2.1}, + {"eval_loss": 2.3}, + {"eval_loss": 2.4}, + {"eval_loss": 2.5}, + ], ) @@ -83,9 +112,142 @@ def test_loss_on_threshold(): ) control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events - tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) + # Trigger rule and test the condition + tc_callback.on_log(args=test_data.args, state=test_data.states[2], control=control) + assert control.should_training_stop is True + + +def test_thresholded_training_loss(): + """Tests the thresholded training loss example in + `examples/trainercontroller_configs/thresholded-training-loss.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_THRESHOLDED_TRAINING_LOSS_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log(args=test_data.args, state=test_data.states[2], control=control) + assert control.should_training_stop is True + + +def test_non_decreasing_training_loss(): + """Tests the non-decreasing training loss example in + `examples/trainercontroller_configs/non-decreasing-training-loss.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_NON_DECREASING_TRAINING_LOSS_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) + # Trigger rule and test the condition + incremental_history = [] + original_history = test_data.states[2].log_history + for log in original_history: + incremental_history.append(log) + test_data.states[2].log_history = incremental_history + tc_callback.on_log( + args=test_data.args, state=test_data.states[2], control=control + ) + assert control.should_training_stop is True + + +def test_epoch_level_training_loss(): + """Tests the epoch level training loss example in + `examples/trainercontroller_configs/epoch-level-training-loss.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_EPOCH_LEVEL_TRAINING_LOSS_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) + # Trigger rule and test the condition + incremental_history = [] + original_history = test_data.states[2].log_history + test_passes = False + for log in original_history: + incremental_history.append(log) + test_data.states[2].log_history = incremental_history + tc_callback.on_log( + args=test_data.args, state=test_data.states[2], control=control + ) + tc_callback.on_epoch_end( + args=test_data.args, 
state=test_data.states[2], control=control + ) + if control.should_training_stop is True: + test_passes = True + assert test_passes is True + + +def test_epoch_level_eval_loss(): + """Tests the epoch level eval loss example in + `examples/trainercontroller_configs/epoch-level-eval-loss.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_EPOCH_LEVEL_EVAL_LOSS_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) + # Trigger rule and test the condition + tc_callback.on_evaluate( + args=test_data.args, + state=test_data.states[2], + control=control, + metrics=test_data.metrics[0], + ) + tc_callback.on_epoch_end( + args=test_data.args, state=test_data.states[2], control=control + ) + assert control.should_training_stop is True + + +def test_epoch_level_eval_loss_patience(): + """Tests the epoch level eval loss with patience threshold example in + `examples/trainercontroller_configs/epoch-level-eval-loss-patience.yaml` + """ + test_data = _setup_data() + tc_callback = tc.TrainerControllerCallback( + td.TRAINER_CONFIG_TEST_EPOCH_LEVEL_EVAL_LOSS_PATIENCE_YAML + ) + control = TrainerControl(should_training_stop=False) + # Trigger on_init_end to perform registration of handlers to events + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) + # Trigger rule and test the condition + for metrics in test_data.metrics: + control = TrainerControl(should_training_stop=False) + tc_callback.on_evaluate( + args=test_data.args, + state=test_data.states[2], + control=control, + metrics=metrics, + ) + tc_callback.on_epoch_end( + args=test_data.args, state=test_data.states[2], control=control + ) + if control.should_training_stop: + break assert control.should_training_stop is True @@ -99,9 +261,11 @@ def test_loss_on_threshold_with_trainer_state(): ) control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events - tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log(args=test_data.args, state=test_data.states[2], control=control) def test_exposed_metrics(): @@ -113,10 +277,12 @@ def test_exposed_metrics(): control = TrainerControl(should_training_stop=False) metrics = {"eval_loss": 2.2} # Trigger on_init_end to perform registration of handlers to events - tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) # Trigger rule and test the condition tc_callback.on_evaluate( - args=test_data.args, state=test_data.state, control=control, metrics=metrics + args=test_data.args, state=test_data.states[2], control=control, metrics=metrics ) assert control.should_training_stop is True @@ -134,11 +300,14 @@ def test_incorrect_source_event_exposed_metrics(): metrics = {"eval_loss": 2.2} # Trigger on_init_end to perform registration of handlers to events tc_callback.on_init_end( - args=test_data.args, state=test_data.state, control=control + args=test_data.args, state=test_data.states[2], 
control=control ) # Trigger rule and test the condition tc_callback.on_evaluate( - args=test_data.args, state=test_data.state, control=control, metrics=metrics + args=test_data.args, + state=test_data.states[2], + control=control, + metrics=metrics, ) assert ( str(exception_handler.value).strip("'") @@ -158,9 +327,11 @@ def test_custom_metric_handler(): tc_callback.register_metric_handlers([CustomMetric]) control = TrainerControl() # Trigger on_init_end to perform registration of handlers to events - tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log(args=test_data.args, state=test_data.states[2], control=control) assert control.should_training_stop is True @@ -175,9 +346,11 @@ def test_custom_operation_handler(): tc_callback.register_operation_handlers([CustomOperation]) control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events - tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log(args=test_data.args, state=test_data.states[2], control=control) assert control.should_training_stop is True @@ -194,13 +367,15 @@ def test_custom_operation_invalid_action_handler(): control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events tc_callback.on_init_end( - args=test_data.args, state=test_data.state, control=control + args=test_data.args, state=test_data.states[2], control=control ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log( + args=test_data.args, state=test_data.states[2], control=control + ) assert str(exception_handler.value).strip("'") == ( - "Invalid operation customoperation.should_ for control" - + " loss-controller-custom-operation-invalid-action" + "Invalid operation custom_operation.should_ for control" + + " loss_controller_custom_operation_invalid_action" ) @@ -216,10 +391,12 @@ def test_invalid_type_rule(): control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events tc_callback.on_init_end( - args=test_data.args, state=test_data.state, control=control + args=test_data.args, state=test_data.states[2], control=control ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log( + args=test_data.args, state=test_data.states[2], control=control + ) assert str(exception_handler.value) == "Rule failed due to incorrect type usage" @@ -235,13 +412,15 @@ def test_malicious_os_rule(): control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events tc_callback.on_init_end( - args=test_data.args, state=test_data.state, control=control + args=test_data.args, state=test_data.states[2], control=control ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log( + 
args=test_data.args, state=test_data.states[2], control=control + ) assert ( str(exception_handler.value) - == "Rule for control loss-controller-wrong-os-rule is invalid" + == "Rule for control loss_controller_wrong_os_rule is invalid" ) @@ -257,10 +436,12 @@ def test_malicious_input_rule(): with pytest.raises(FunctionNotDefined) as exception_handler: # Trigger on_init_end to perform registration of handlers to events tc_callback.on_init_end( - args=test_data.args, state=test_data.state, control=control + args=test_data.args, state=test_data.states[2], control=control ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log( + args=test_data.args, state=test_data.states[2], control=control + ) assert ( str(exception_handler.value) == "Function 'input' not defined, for expression 'input('Please enter your password:')'." @@ -279,12 +460,14 @@ def test_invalid_trigger(): control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events tc_callback.on_init_end( - args=test_data.args, state=test_data.state, control=control + args=test_data.args, state=test_data.states[2], control=control ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log( + args=test_data.args, state=test_data.states[2], control=control + ) assert str(exception_handler.value).strip("'") == ( - "Controller loss-controller-invalid-trigger has" + "Controller loss_controller_invalid_trigger has" + " an invalid event (log_it_all_incorrect_trigger_name)" ) @@ -301,13 +484,15 @@ def test_invalid_operation(): control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events tc_callback.on_init_end( - args=test_data.args, state=test_data.state, control=control + args=test_data.args, state=test_data.states[2], control=control ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log( + args=test_data.args, state=test_data.states[2], control=control + ) assert str(exception_handler.value).strip("'") == ( "Invalid operation missingop.should_training_stop" - + " for control loss-controller-invalid-operation" + + " for control loss_controller_invalid_operation" ) @@ -323,13 +508,15 @@ def test_invalid_operation_action(): control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events tc_callback.on_init_end( - args=test_data.args, state=test_data.state, control=control + args=test_data.args, state=test_data.states[2], control=control ) # Trigger rule and test the condition - tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log( + args=test_data.args, state=test_data.states[2], control=control + ) assert str(exception_handler.value).strip("'") == ( "Invalid operation hfcontrols.missingaction" - + " for control loss-controller-invalid-operation-action" + + " for control loss_controller_invalid_operation_action" ) @@ -345,10 +532,12 @@ def test_invalid_metric(): control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events tc_callback.on_init_end( - args=test_data.args, state=test_data.state, control=control + args=test_data.args, state=test_data.states[2], control=control ) # Trigger rule and test the condition - 
tc_callback.on_log(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_log( + args=test_data.args, state=test_data.states[2], control=control + ) assert ( str(exception_handler.value).strip("'") == "Undefined metric handler MissingMetricClass" ) @@ -365,6 +554,10 @@ def test_unavailable_metric(): ) control = TrainerControl(should_training_stop=False) # Trigger on_init_end to perform registration of handlers to events - tc_callback.on_init_end(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_init_end( + args=test_data.args, state=test_data.states[2], control=control + ) # Trigger rule and test the condition - tc_callback.on_step_end(args=test_data.args, state=test_data.state, control=control) + tc_callback.on_step_end( + args=test_data.args, state=test_data.states[2], control=control + ) diff --git a/tuning/trainercontroller/callback.py b/tuning/trainercontroller/callback.py index b14470332..5ade20323 100644 --- a/tuning/trainercontroller/callback.py +++ b/tuning/trainercontroller/callback.py @@ -35,7 +35,7 @@ # Local from tuning.trainercontroller import controllermetrics, operations -from tuning.trainercontroller.control import Control, OperationAction +from tuning.trainercontroller.control import Control, OperationAction, Rule from tuning.trainercontroller.controllermetrics import ( handlers as default_metric_handlers, ) @@ -44,12 +44,13 @@ from tuning.trainercontroller.operations import ( operation_handlers as default_operation_handlers, ) +from tuning.trainercontroller.patience import PatienceControl from tuning.utils.evaluator import MetricUnavailableError, RuleEvaluator logger = logging.get_logger(__name__) # Configuration keys -CONTROLLER_METRICS_KEY = "controller-metrics" +CONTROLLER_METRICS_KEY = "controller_metrics" OPERATIONS_KEY = "operations" CONTROLLERS_KEY = "controllers" ARGS_KEY = "arguments" @@ -57,6 +58,8 @@ CONTROLLER_NAME_KEY = "name" CONTROLLER_CLASS_KEY = "class" CONTROLLER_RULE_KEY = "rule" +CONTROLLER_CONFIG_KEY = "config" +CONTROLLER_PATIENCE_CONFIG_KEY = "patience" CONTROLLER_TRIGGERS_KEY = "triggers" CONTROLLER_OPERATIONS_KEY = OPERATIONS_KEY @@ -222,8 +225,8 @@ def _take_control_actions(self, event_name: str, **kwargs): rule_succeeded = False try: rule_succeeded = evaluator.eval( - expr=control_action.rule_str, - previously_parsed=control_action.rule, + expr=control_action.rule.rule, + previously_parsed=control_action.rule.rule_ast, ) if not isinstance(rule_succeeded, bool): raise TypeError( @@ -251,10 +254,19 @@ def _take_control_actions(self, event_name: str, **kwargs): except MetricUnavailableError as em: logger.warning("Ignoring the rule because %s", em) continue + if ( + control_action.patience is not None + and control_action.patience.should_tolerate( + rule_outcome=rule_succeeded, + event_name=event_name, + control_name=control_action.name, + ) + ): + continue if rule_succeeded: for operation_action in control_action.operation_actions: logger.info( - "Taking %s action in %s", + "Taking [%s] action in controller [%s]", operation_action.action, control_action.name, ) @@ -393,13 +405,20 @@ def on_init_end( % (controller_name, event_name) ) # Generates the byte-code for the rule from the trainer configuration - curr_rule = controller[CONTROLLER_RULE_KEY] + controller_rule = controller[CONTROLLER_RULE_KEY] control = Control( name=controller[CONTROLLER_NAME_KEY], - rule_str=curr_rule, - rule=EvalWithCompoundTypes.parse(expr=curr_rule), + rule=Rule( + rule=controller_rule, + rule_ast=EvalWithCompoundTypes.parse(expr=controller_rule), + ), operation_actions=[], 
) + if CONTROLLER_CONFIG_KEY in controller: + control.config = controller[CONTROLLER_CONFIG_KEY] + if CONTROLLER_PATIENCE_CONFIG_KEY in controller: + control.patience = PatienceControl( + **controller[CONTROLLER_PATIENCE_CONFIG_KEY] + ) for control_operation_name in controller_ops: if control_operation_name not in self.operation_actions: raise KeyError( diff --git a/tuning/trainercontroller/control.py b/tuning/trainercontroller/control.py index 4c8b6a6d4..e995d0f17 100644 --- a/tuning/trainercontroller/control.py +++ b/tuning/trainercontroller/control.py @@ -17,11 +17,12 @@ # Standard from dataclasses import dataclass -from typing import List, Optional +from typing import Dict, List, Optional import ast # Local from tuning.trainercontroller.operations import Operation +from tuning.trainercontroller.patience import PatienceControl @dataclass @@ -32,11 +33,22 @@ class OperationAction: action: str +@dataclass +class Rule: + """Stores the rule and its configuration""" + + rule: str + rule_ast: Optional[ + ast.AST + ] = None # stores the abstract syntax tree of the parsed rule + + @dataclass class Control: """Stores the name of control, rule byte-code corresponding actions""" name: str - rule_str: str - rule: Optional[ast.AST] = None # stores the abstract syntax tree of the parsed rule + rule: Rule + patience: Optional[PatienceControl] = None operation_actions: Optional[List[OperationAction]] = None + config: Optional[Dict] = None diff --git a/tuning/trainercontroller/controllermetrics/__init__.py b/tuning/trainercontroller/controllermetrics/__init__.py index 1c0ffe59f..1f9f76705 100644 --- a/tuning/trainercontroller/controllermetrics/__init__.py +++ b/tuning/trainercontroller/controllermetrics/__init__.py @@ -20,6 +20,7 @@ # Local from .eval_metrics import EvalMetrics +from .history_based_metrics import HistoryBasedMetric from .loss import Loss from .trainingstate import TrainingState @@ -40,3 +41,4 @@ def register(cl: Type): register(TrainingState) register(EvalMetrics) register(Loss) +register(HistoryBasedMetric) diff --git a/tuning/trainercontroller/controllermetrics/eval_metrics.py b/tuning/trainercontroller/controllermetrics/eval_metrics.py index c3f140f97..696714437 100644 --- a/tuning/trainercontroller/controllermetrics/eval_metrics.py +++ b/tuning/trainercontroller/controllermetrics/eval_metrics.py @@ -38,10 +38,10 @@ def __init__(self, **kwargs): kwargs: List of arguments (key, value)-pairs """ source_events_to_check = {"on_evaluate", "on_predict"} - source_event = kwargs.get("source-event") + source_event = kwargs.get("source_event") if source_event is None: source_event = "on_evaluate" - elif source_event in source_events_to_check: + if source_event in source_events_to_check: super().__init__( events=[ source_event, diff --git a/tuning/trainercontroller/controllermetrics/history_based_metrics.py b/tuning/trainercontroller/controllermetrics/history_based_metrics.py new file mode 100644 index 000000000..ae547d3c6 --- /dev/null +++ b/tuning/trainercontroller/controllermetrics/history_based_metrics.py @@ -0,0 +1,139 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Standard +from collections import deque +from typing import Any + +# Third Party +from transformers import TrainerState +from transformers.utils import logging + +# Local +from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler + +logger = logging.get_logger(__name__) +METRICS_KEY = "metrics" +LOG_LOSS_KEY = "loss" +TRAINING_LOSS_KEY = "training_loss" +WINDOW_SIZE = "window_size" +STEP_KEY = "steps" +EPOCH_KEY = "epoch" + + +class HistoryBasedMetric(MetricHandler): + """Implements a controller metric which maintains a sliding window \ + of the training loss and evaluation metrics""" + + def __init__(self, window_size=1, **kwargs): + """Initializes the metric handler by registering the event \ + list and arguments with the base handler. + + Args: + window_size: Size of the sliding window. A negative value \ + retains the entire history. + kwargs: List of arguments (key, value)-pairs + """ + self._window = { + TRAINING_LOSS_KEY: {}, + METRICS_KEY: {}, + WINDOW_SIZE: window_size, + } + super().__init__(events=["on_log", "on_evaluate"], **kwargs) + + def _add_and_slide(self, data_type: str, data: dict) -> bool: + """Appends the value of each field in the data to its window \ + vector, sliding the window when it grows beyond the window size. + + Args: + data_type: Data type (vector group) to update. + data: (key, value)-pairs to append. + + Returns: + bool: True if the window is full (or unbounded), False otherwise. + """ + data_sources = list(self._window[data_type].keys()) + for data_source in data_sources: + self._window[data_type][data_source].append(data[data_source]) + window_size = self._window[WINDOW_SIZE] + if window_size < 0: + return True + # All metrics in a data_type group are expected to be computed together + if len(self._window[data_type][data_sources[0]]) < window_size: + return False + if len(self._window[data_type][data_sources[0]]) == window_size: + return True + for data_source in data_sources: + self._window[data_type][data_source].popleft() + return True + + def validate(self) -> bool: + """Validate the training arguments (e.g. logging_steps) are \ + compatible with the computation of this metric. + + Returns: + bool + """ + return True + + def _create_vectors_if_not_exists(self, data_type: str, data_sources: list): + """Creates vectors for each field in the data source. + + Args: + data_type: Data type. + data_sources: Keys in data source. + """ + if len(self._window[data_type]) > 0: + return + for data_source_name in data_sources: + self._window[data_type][data_source_name] = deque() + + def compute(self, state: TrainerState = None, **kwargs) -> Any: + """Exposes the window of loss and metrics values in the log. + + Args: + state: TrainerState object + kwargs: Remaining event arguments + + Returns: + Any. The exposed variables are returned here. 
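+ For illustration, with window_size=2 the returned structure + looks like the following (values are hypothetical): + {"training_loss": {"loss": deque([2.1, 1.3]), "steps": deque([2, 3]), + "epoch": deque([0.25, 0.5])}, + "metrics": {"eval_loss": deque([2.2, 2.1]), "steps": deque([2, 3]), + "epoch": deque([1.0, 2.0])}, + "window_size": 2}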
+ """ + if METRICS_KEY in kwargs: + metrics = kwargs[METRICS_KEY] + metrics[STEP_KEY] = state.global_step + metrics[EPOCH_KEY] = state.epoch + self._create_vectors_if_not_exists(METRICS_KEY, list(metrics.keys())) + self._add_and_slide(METRICS_KEY, metrics) + else: + self._create_vectors_if_not_exists( + TRAINING_LOSS_KEY, [LOG_LOSS_KEY, STEP_KEY, EPOCH_KEY] + ) + size_of_log_history = len(state.log_history) + for i in range(size_of_log_history - 1, -1, -1): + log = state.log_history[i] + if LOG_LOSS_KEY in log: + data = { + LOG_LOSS_KEY: float(log[LOG_LOSS_KEY]), + STEP_KEY: state.global_step, + EPOCH_KEY: float(log[EPOCH_KEY]), + } + loss_data = self._window[TRAINING_LOSS_KEY][LOG_LOSS_KEY] + epoch_data = self._window[TRAINING_LOSS_KEY][EPOCH_KEY] + if ( + len(loss_data) == 0 + or loss_data[-1] != data[LOG_LOSS_KEY] + or epoch_data[-1] != data[EPOCH_KEY] + ): + self._add_and_slide(TRAINING_LOSS_KEY, data) + break + return self._window diff --git a/tuning/trainercontroller/patience.py b/tuning/trainercontroller/patience.py new file mode 100644 index 000000000..b8098fdf0 --- /dev/null +++ b/tuning/trainercontroller/patience.py @@ -0,0 +1,76 @@ +# Copyright The IBM Tuning Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# SPDX-License-Identifier: Apache-2.0 +# https://spdx.dev/learn/handling-license-info/ + +# Third Party +from transformers.utils import logging + +# Resets the patience if the rule outcome happens to be false. +# Here, the expectation is to have unbroken "True"s for patience +# to be up-countered. +# E.g. For patience threshold, patience_threshold=3, rule outcome +# has to be T, T, T, T (each is an event +# then patience is reset at the third event when outcome is F. +MODE_RESET_ON_FAILURE = "reset_on_failure" + +# This mode does not reset patience. E.g if rule outcome is T, T, F, T, T, +# then the patience counter is not reset at F. Instead, the patience threshold +# will be exceeded afer the fifth event. 
+MODE_NO_RESET_ON_FAILURE = "no_reset_on_failure" + +logger = logging.get_logger(__name__) + + +class PatienceControl: + """Implements the patience control for every rule""" + + # pylint: disable=unused-argument + def __init__(self, patience_threshold=1, mode=MODE_RESET_ON_FAILURE, **kwargs): + self._patience_threshold = patience_threshold + self._patience_counter = 0 + self._mode = mode + + def should_tolerate( + self, rule_outcome: bool, event_name=None, control_name=None, **kwargs + ) -> bool: + """Tracks true rule outcomes and returns True while the patience + threshold has not been exceeded, i.e. while the control action + should still be held back. Once the threshold is exceeded, the + counter is reset and False is returned so the action can be taken. + """ + if rule_outcome: + self._patience_counter = self._patience_counter + 1 + elif self._mode == MODE_RESET_ON_FAILURE: + self._patience_counter = 0 + if self._patience_counter <= self._patience_threshold: + logger.debug( + "Control {} triggered on event {}: " + "Enforcing patience [patience_counter = {:.2f}, " + "patience_threshold = {:.2f}]".format( + control_name, + event_name, + self._patience_counter, + self._patience_threshold, + ) + ) + return True + logger.debug( + "Control {} triggered on event {}: " + "Exceeded patience [patience_counter = {:.2f}, " + "patience_threshold = {:.2f}]".format( + control_name, + event_name, + self._patience_counter, + self._patience_threshold, + ) + ) + self._patience_counter = 0 + return False
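+ + +# A minimal usage sketch (hypothetical threshold and outcomes): +# +# patience = PatienceControl(patience_threshold=2, mode=MODE_NO_RESET_ON_FAILURE) +# for outcome in [True, True, False, True]: +# if not patience.should_tolerate(rule_outcome=outcome): +# break # patience exhausted at the fourth event; take the action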