16
16
# https://spdx.dev/learn/handling-license-info/
17
17
18
18
# Standard
19
- from importlib import resources as impresources
20
19
from typing import Dict , List , Union
21
20
import inspect
22
21
import os
34
33
import yaml
35
34
36
35
# Local
37
- from tuning .trainercontroller import controllermetrics , operations
38
36
from tuning .trainercontroller .control import Control , OperationAction , Rule
39
37
from tuning .trainercontroller .controllermetrics import (
40
38
handlers as default_metric_handlers ,
63
61
CONTROLLER_TRIGGERS_KEY = "triggers"
64
62
CONTROLLER_OPERATIONS_KEY = OPERATIONS_KEY
65
63
64
+ # Default operations / metrics to register
65
+ DEFAULT_OPERATIONS = {"operations" : [{"name" : "hfcontrols" , "class" : "HFControls" }]}
66
+ DEFAULT_METRICS = {}
66
67
67
68
# pylint: disable=too-many-instance-attributes
68
69
class TrainerControllerCallback (TrainerCallback ):
@@ -102,23 +103,15 @@ def __init__(self, trainer_controller_config: Union[dict, str]):
102
103
if OPERATIONS_KEY not in self .trainer_controller_config :
103
104
self .trainer_controller_config [OPERATIONS_KEY ] = []
104
105
105
- # Initialize the list of metrics from default `metrics.yaml` in the \
106
- # controllermetric package. In addition, any metrics mentioned in \
107
- # the trainer controller config are added to this list.
108
- default_metrics_config_yaml = (
109
- impresources .files (controllermetrics ) / "metrics.yaml"
110
- )
111
- with default_metrics_config_yaml .open ("r" ) as f :
112
- default_metrics_config = yaml .safe_load (f )
113
106
if (
114
- default_metrics_config is not None
115
- and CONTROLLER_METRICS_KEY in default_metrics_config
116
- and len (default_metrics_config [CONTROLLER_METRICS_KEY ]) > 0
107
+ DEFAULT_METRICS
108
+ and CONTROLLER_METRICS_KEY in DEFAULT_METRICS
109
+ and len (DEFAULT_METRICS [CONTROLLER_METRICS_KEY ]) > 0
117
110
):
118
111
self_controller_metrics = self .trainer_controller_config [
119
112
CONTROLLER_METRICS_KEY
120
113
]
121
- default_controller_metrics : list [dict ] = default_metrics_config [
114
+ default_controller_metrics : list [dict ] = DEFAULT_METRICS [
122
115
CONTROLLER_METRICS_KEY
123
116
]
124
117
for metric_obj in default_controller_metrics :
@@ -131,21 +124,13 @@ def __init__(self, trainer_controller_config: Union[dict, str]):
131
124
if not found :
132
125
self_controller_metrics .append (metric_obj )
133
126
134
- # Initialize the list of operations from default `operations.yaml` \
135
- # in the operations package. In addition, any operations mentioned \
136
- # in the trainer controller config are added to this list.
137
- default_operations_config_yaml = (
138
- impresources .files (operations ) / "operations.yaml"
139
- )
140
- with default_operations_config_yaml .open ("r" ) as f :
141
- default_operations_config = yaml .safe_load (f )
142
127
if (
143
- default_operations_config is not None
144
- and OPERATIONS_KEY in default_operations_config
145
- and len (default_operations_config [OPERATIONS_KEY ]) > 0
128
+ DEFAULT_OPERATIONS
129
+ and OPERATIONS_KEY in DEFAULT_OPERATIONS
130
+ and len (DEFAULT_OPERATIONS [OPERATIONS_KEY ]) > 0
146
131
):
147
132
self_controller_operations = self .trainer_controller_config [OPERATIONS_KEY ]
148
- default_controller_operations : list [dict ] = default_operations_config [
133
+ default_controller_operations : list [dict ] = DEFAULT_OPERATIONS [
149
134
OPERATIONS_KEY
150
135
]
151
136
for op_obj in default_controller_operations :
0 commit comments