Commit

Merge remote-tracking branch 'upstream/main'
dchourasia committed Jul 11, 2024
2 parents 13fd496 + 85f32cb commit e43c44e
Showing 37 changed files with 711 additions and 100 deletions.
5 changes: 2 additions & 3 deletions architecture_records/001-trainer-controller-framework.md
@@ -98,7 +98,7 @@ controller-metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller
+  - name: loss_controller
     triggers:
       - on_log
     rule: "loss < 1.0"
@@ -107,9 +107,8 @@ controllers:
 ```
 We follow the below naming convention for the above trainer controller configuration:
-1. `-` could be used in the case of key names, and name of the metric, operation or controller. This is usually to break multiple words of a name phrase.
 1. Python convention for [class name](https://visualgit.readthedocs.io/en/latest/pages/naming_convention.html#classes).
-1. `_` are used for events and control actions.
+1. `_` should be used between words in keys, values, events and control actions.

For defining custom handler classes, we have an interface defined as an abstract class as shown below, with two abstract methods, namely: `validate()` to define the validation conditions, and `compute()` to compute the metric. The `compute()` returns an `Any` type. While it could be any value, developers should keep in mind that it should be only key-value pairs that are used in the rule(s) defined in the configuration.
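
As a rough illustration of the interface this paragraph describes, the sketch below defines an abstract handler with `validate()` and `compute()` and one concrete subclass. The names `MetricHandlerSketch` and `LossSketch`, the constructor arguments, and the `logs` keyword are hypothetical and are not taken from the repository; only the two method names come from the text.

```python
from abc import ABC, abstractmethod
from typing import Any, List


class MetricHandlerSketch(ABC):
    """Hypothetical base class; the project's real interface may differ."""

    def __init__(self, name: str, events: List[str]):
        self.name = name      # name the rule uses to refer to this metric
        self.events = events  # events on which the metric is computed

    @abstractmethod
    def validate(self) -> bool:
        """Return True when the handler is configured in a usable way."""

    @abstractmethod
    def compute(self, **kwargs: Any) -> Any:
        """Return the value that rules will see under `self.name`."""


class LossSketch(MetricHandlerSketch):
    def validate(self) -> bool:
        # Assumption: a loss value is only available when logs are emitted.
        return "on_log" in self.events

    def compute(self, **kwargs: Any) -> Any:
        logs = kwargs.get("logs") or {}
        return logs.get("loss")


metric = LossSketch(name="loss", events=["on_log"])
value = metric.compute(logs={"loss": 0.8})
```

With this shape, a rule such as `loss < 1.0` would be evaluated against `{metric.name: value}`.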

@@ -0,0 +1,12 @@
+controller_metrics:
+  - name: trainer_state
+    class: TrainingState
+  - name: evalmetric
+    class: EvalMetrics
+controllers:
+  - name: epoch_level_eval_loss_below_threshold
+    triggers:
+      - on_epoch_end
+    rule: evalmetric['eval_loss'] < 2.25 and trainer_state["epoch"] > 2
+    operations:
+      - hfcontrols.should_training_stop
@@ -0,0 +1,14 @@
+controller_metrics:
+  - name: eval_loss_window
+    class: HistoryBasedMetric
+    arguments:
+      window_size: 2
+controllers:
+  - name: epoch_level_eval_loss_patience
+    triggers:
+      - on_epoch_end
+    rule: len(eval_loss_window["metrics"]) > 0 and eval_loss_window["metrics"]["eval_loss"][-1] > 2
+    patience:
+      patience_threshold: 2
+    operations:
+      - hfcontrols.should_training_stop
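
The `patience` block in the config above suggests that the operation fires only after the rule has held several times. The sketch below is one possible reading of that behaviour, not the project's implementation: the class name `PatienceSketch`, the reset on a failed rule, and the strict `>` comparison against `patience_threshold` are all assumptions.

```python
class PatienceSketch:
    """Hypothetical patience counter for a controller rule."""

    def __init__(self, patience_threshold: int):
        self.patience_threshold = patience_threshold
        self.counter = 0  # consecutive events on which the rule was true

    def should_act(self, rule_outcome: bool) -> bool:
        if not rule_outcome:
            # Assumption: a false rule resets the streak.
            self.counter = 0
            return False
        self.counter += 1
        if self.counter > self.patience_threshold:
            self.counter = 0
            return True
        return False


patience = PatienceSketch(patience_threshold=2)
eval_losses = [2.4, 2.3, 1.9, 2.5, 2.6, 2.7]
# Rule from the config: the latest eval loss is above 2.
decisions = [patience.should_act(loss > 2) for loss in eval_losses]
```

Under these assumptions the operation triggers only on the last value, after the rule has been true three times in a row.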
14 changes: 14 additions & 0 deletions examples/trainercontroller_configs/epoch-level-eval-loss.yaml
@@ -0,0 +1,14 @@
+controller_metrics:
+  - name: trainer_state
+    class: TrainingState
+  - name: eval_loss_window
+    class: HistoryBasedMetric
+    arguments:
+      window_size: 1
+controllers:
+  - name: epoch_level_eval_loss
+    triggers:
+      - on_epoch_end
+    rule: len(eval_loss_window["metrics"]) > 0 and eval_loss_window["metrics"]["eval_loss"][-1] > 2.2 and trainer_state["epoch"] > 3
+    operations:
+      - hfcontrols.should_training_stop
@@ -0,0 +1,12 @@
+controller_metrics:
+  - name: training_loss_window
+    class: HistoryBasedMetric
+    arguments:
+      window_size: 1
+controllers:
+  - name: epoch_level_stop_on_training_loss_below_threshold
+    triggers:
+      - on_log
+    rule: len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"] and training_loss_window["training_loss"]["loss"][0] < 2.2 and training_loss_window["training_loss"]["epoch"][0] > 2
+    operations:
+      - hfcontrols.should_training_stop
14 changes: 14 additions & 0 deletions examples/trainercontroller_configs/epoch-level-training-loss.yaml
@@ -0,0 +1,14 @@
+controller_metrics:
+  - name: trainer_state
+    class: TrainingState
+  - name: training_loss_window
+    class: HistoryBasedMetric
+    arguments:
+      window_size: 1
+controllers:
+  - name: epoch_level_training_loss
+    triggers:
+      - on_epoch_end
+    rule: training_loss_window["training_loss"]["loss"][-1] > 2 and trainer_state["epoch"] > 3
+    operations:
+      - hfcontrols.should_training_stop
4 changes: 2 additions & 2 deletions examples/trainercontroller_configs/loss.yaml
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller
+  - name: loss_controller
     triggers:
       - on_log
     rule: loss < 1.0
@@ -0,0 +1,12 @@
+controller_metrics:
+  - name: training_loss_window
+    class: HistoryBasedMetric
+    arguments:
+      window_size: 5
+controllers:
+  - name: stop_on_training_loss_not_decreasing
+    triggers:
+      - on_log
+    rule: training_loss_window["training_loss"]["loss"][0] < training_loss_window["training_loss"]["loss"][-1] and len(training_loss_window["training_loss"]["loss"]) == training_loss_window["window_size"]
+    operations:
+      - hfcontrols.should_training_stop
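
To make the windowed rule above concrete, the sketch below keeps the last `window_size` training losses in a bounded `deque` and evaluates the same "oldest loss is lower than newest loss, and the window is full" condition. `LossWindowSketch` and `not_decreasing` are hypothetical names, and the returned dictionary only mirrors the keys the rule reads; how `HistoryBasedMetric` actually stores its history is not shown in this commit.

```python
from collections import deque


class LossWindowSketch:
    """Hypothetical stand-in for a history-based metric."""

    def __init__(self, window_size: int):
        self.window_size = window_size
        self.losses = deque(maxlen=window_size)  # oldest entries fall off

    def record(self, loss: float) -> dict:
        self.losses.append(loss)
        return {
            "training_loss": {"loss": list(self.losses)},
            "window_size": self.window_size,
        }


def not_decreasing(window: dict) -> bool:
    losses = window["training_loss"]["loss"]
    # Length is checked first so a partially filled window is never indexed.
    return len(losses) == window["window_size"] and losses[0] < losses[-1]


tracker = LossWindowSketch(window_size=5)
outcomes = [
    not_decreasing(tracker.record(loss))
    for loss in [2.0, 1.8, 1.9, 2.1, 2.2, 2.3]
]
```

The rule stays false until five losses have been recorded, then becomes true because the oldest loss in the window is below the newest one.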
12 changes: 12 additions & 0 deletions examples/trainercontroller_configs/thresholded-training-loss.yaml
@@ -0,0 +1,12 @@
+controller_metrics:
+  - name: training_loss_window
+    class: HistoryBasedMetric
+    arguments:
+      window_size: 1
+controllers:
+  - name: stop_on_training_loss_not_decreasing
+    triggers:
+      - on_log
+    rule: training_loss_window["training_loss"]["loss"][-1] > 2.2
+    operations:
+      - hfcontrols.should_training_stop
15 changes: 15 additions & 0 deletions tests/data/trainercontroller/__init__.py
@@ -62,3 +62,18 @@
 TRAINER_CONFIG_TEST_CUSTOM_OPERATION_INVALID_ACTION_YAML = os.path.join(
     _DATA_DIR, "loss_custom_operation_invalid_action.yaml"
 )
+TRAINER_CONFIG_TEST_EPOCH_LEVEL_EVAL_LOSS_PATIENCE_YAML = os.path.join(
+    _DATA_DIR, "epoch-level-eval-loss-patience.yaml"
+)
+TRAINER_CONFIG_TEST_EPOCH_LEVEL_EVAL_LOSS_YAML = os.path.join(
+    _DATA_DIR, "epoch-level-eval-loss.yaml"
+)
+TRAINER_CONFIG_TEST_EPOCH_LEVEL_TRAINING_LOSS_YAML = os.path.join(
+    _DATA_DIR, "epoch-level-training-loss.yaml"
+)
+TRAINER_CONFIG_TEST_NON_DECREASING_TRAINING_LOSS_YAML = os.path.join(
+    _DATA_DIR, "non-decreasing-training-loss.yaml"
+)
+TRAINER_CONFIG_TEST_THRESHOLDED_TRAINING_LOSS_YAML = os.path.join(
+    _DATA_DIR, "thresholded-training-loss.yaml"
+)
14 changes: 14 additions & 0 deletions tests/data/trainercontroller/epoch-level-eval-loss-patience.yaml
@@ -0,0 +1,14 @@
+controller_metrics:
+  - name: eval_loss_window
+    class: HistoryBasedMetric
+    arguments:
+      window_size: 2
+controllers:
+  - name: epoch_level_eval_loss_patience
+    triggers:
+      - on_epoch_end
+    rule: len(eval_loss_window["metrics"]["eval_loss"]) > 0 and eval_loss_window["metrics"]["eval_loss"][-1] > 2
+    patience:
+      patience_threshold: 2
+    operations:
+      - hfcontrols.should_training_stop
14 changes: 14 additions & 0 deletions tests/data/trainercontroller/epoch-level-eval-loss.yaml
@@ -0,0 +1,14 @@
+controller_metrics:
+  - name: trainer_state
+    class: TrainingState
+  - name: eval_loss_window
+    class: HistoryBasedMetric
+    arguments:
+      window_size: 1
+controllers:
+  - name: epoch_level_eval_loss
+    triggers:
+      - on_epoch_end
+    rule: len(eval_loss_window["metrics"]["eval_loss"]) > 0 and eval_loss_window["metrics"]["eval_loss"][-1] > 2 and trainer_state["epoch"] > 0.1
+    operations:
+      - hfcontrols.should_training_stop
14 changes: 14 additions & 0 deletions tests/data/trainercontroller/epoch-level-training-loss.yaml
@@ -0,0 +1,14 @@
+controller_metrics:
+  - name: trainer_state
+    class: TrainingState
+  - name: training_loss_window
+    class: HistoryBasedMetric
+    arguments:
+      window_size: 1
+controllers:
+  - name: epoch_level_training_loss
+    triggers:
+      - on_epoch_end
+    rule: training_loss_window["training_loss"]["loss"][-1] < 1 and trainer_state["epoch"] >= 0.5
+    operations:
+      - hfcontrols.should_training_stop
6 changes: 3 additions & 3 deletions tests/data/trainercontroller/exposed_metrics.yaml
@@ -1,10 +1,10 @@
-controller-metrics:
+controller_metrics:
   - name: evalmetric
     class: EvalMetrics
     arguments:
-      source-event: on_evaluate
+      source_event: on_evaluate
 controllers:
-  - name: loss-controller
+  - name: loss_controller
     triggers:
       - on_evaluate
     rule: evalmetric['eval_loss'] < 2.5
@@ -1,10 +1,10 @@
-controller-metrics:
+controller_metrics:
   - name: evalmetric
     class: EvalMetrics
     arguments:
-      source-event: on_incorrect_event
+      source_event: on_incorrect_event
 controllers:
-  - name: loss-controller
+  - name: loss_controller
     triggers:
       - on_evaluate
     rule: evalmetric['eval_loss'] < 2.5
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_custom_metric.yaml
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: testflag
     class: CustomMetric
 controllers:
-  - name: loss-controller-custom-metric
+  - name: loss_controller_custom-metric
     triggers:
       - on_log
     rule: testflag == True
8 changes: 4 additions & 4 deletions tests/data/trainercontroller/loss_custom_operation.yaml
@@ -1,13 +1,13 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 operations:
-  - name: customoperation
+  - name: custom_operation
     class: CustomOperation
 controllers:
-  - name: loss-controller-custom-operation
+  - name: loss_controller_custom_operation
     triggers:
       - on_log
     rule: loss < 1.0
     operations:
-      - customoperation.should_perform_action_xyz
+      - custom_operation.should_perform_action_xyz
@@ -1,13 +1,13 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 operations:
-  - name: customoperation
+  - name: custom_operation
     class: CustomOperationInvalidAction
 controllers:
-  - name: loss-controller-custom-operation-invalid-action
+  - name: loss_controller_custom_operation_invalid_action
     triggers:
       - on_log
     rule: loss < 1.0
     operations:
-      - customoperation.should_
+      - custom_operation.should_
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_invalid_metric.yaml
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: MissingMetricClass
 controllers:
-  - name: loss-controller-invalid-metric
+  - name: loss_controller_invalid_metric
     triggers:
       - on_log
     rule: loss < 1.0
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_invalid_operation.yaml
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller-invalid-operation
+  - name: loss_controller_invalid_operation
     triggers:
       - on_log
     rule: loss < 1.0
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller-invalid-operation-action
+  - name: loss_controller_invalid_operation_action
     triggers:
       - on_log
     rule: loss < 1.0
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_invalid_trigger.yaml
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller-invalid-trigger
+  - name: loss_controller_invalid_trigger
     triggers:
       - log_it_all_incorrect_trigger_name
     rule: loss < 1.0
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_on_threshold.yaml
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller
+  - name: loss_controller
     triggers:
       - on_log
     rule: loss < 1.0
@@ -1,10 +1,10 @@
-controller-metrics:
+controller_metrics:
   - name: state
     class: TrainingState
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller
+  - name: loss_controller
     triggers:
       - on_log
     rule: loss < 2 and state["epoch"] >= 0.5
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_unavailable_metric.yaml
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller-unavailable-metric
+  - name: loss_controller_unavailable_metric
     triggers:
       - on_step_end
     rule: loss < 1.0
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_with_invalid_type_rule.yaml
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller-wrong-os-rule
+  - name: loss_controller_wrong_os_rule
     triggers:
       - on_log
     rule: "2+2"
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller-wrong-input-rule
+  - name: loss_controller_wrong_input_rule
     triggers:
       - on_log
     rule: input('Please enter your password:')
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/loss_with_malicious_os_rule.yaml
@@ -1,8 +1,8 @@
-controller-metrics:
+controller_metrics:
   - name: loss
     class: Loss
 controllers:
-  - name: loss-controller-wrong-os-rule
+  - name: loss_controller_wrong_os_rule
     triggers:
       - on_log
     rule: __import__('os').system('clear')
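
The last three fixtures feed the controller rules that should be rejected: a non-boolean expression (`"2+2"`), a call to `input(...)`, and an `os.system` call. The sketch below shows one way such rules could be screened before evaluation, using an allow-list over the parsed syntax tree. It is an illustration only, assuming Python 3.9 or later; `evaluate_rule` and the allow-list are hypothetical and are not the project's actual evaluator, and this is not a hardened sandbox.

```python
import ast

# Syntax permitted in a rule: comparisons, boolean logic, simple arithmetic,
# names, constants, subscripts, and calls (restricted further below).
_ALLOWED_NODES = (
    ast.Expression, ast.BoolOp, ast.And, ast.Or,
    ast.UnaryOp, ast.Not, ast.USub,
    ast.BinOp, ast.Add, ast.Sub,
    ast.Compare, ast.Lt, ast.LtE, ast.Gt, ast.GtE, ast.Eq, ast.NotEq,
    ast.Name, ast.Load, ast.Constant, ast.Subscript, ast.Call,
)
_ALLOWED_CALLS = {"len": len}


def evaluate_rule(rule: str, metrics: dict) -> bool:
    tree = ast.parse(rule, mode="eval")
    for node in ast.walk(tree):
        if not isinstance(node, _ALLOWED_NODES):
            raise ValueError(f"disallowed syntax: {type(node).__name__}")
        if isinstance(node, ast.Call):
            func = node.func
            if not (isinstance(func, ast.Name) and func.id in _ALLOWED_CALLS):
                raise ValueError("only len() may be called in a rule")
        if isinstance(node, ast.Name):
            if node.id not in metrics and node.id not in _ALLOWED_CALLS:
                raise ValueError(f"unknown name: {node.id}")
    code = compile(tree, "<rule>", "eval")
    result = eval(code, {"__builtins__": {}}, {**_ALLOWED_CALLS, **metrics})
    if not isinstance(result, bool):
        raise TypeError("a rule must evaluate to a boolean")
    return result


ok = evaluate_rule("loss < 1.0", {"loss": 0.8})
```

Attribute access is excluded from the allow-list entirely, which is what rejects `__import__('os').system('clear')` in this sketch; `input(...)` is rejected because only `len` may be called.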