Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
dchourasia committed Aug 9, 2024
2 parents 08a67f7 + bb0caf9 commit f221687
Show file tree
Hide file tree
Showing 14 changed files with 159 additions and 13 deletions.
16 changes: 16 additions & 0 deletions examples/trainercontroller_configs/log_controller.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
controller_metrics:
- name: trainer_state
class: TrainingState
operations:
- name: logcontrolstep
class: LogControl
arguments:
log_format: 'This is a test log format [{event_name}] => {trainer_state}'
log_level: warning
controllers:
- name: log-controller-step
triggers:
- on_step_end
rule: 'True'
operations:
- logcontrolstep.should_log
19 changes: 18 additions & 1 deletion scripts/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ def parse_and_validate_args():
help="Whether to load the model using Flash Attention 2",
action="store_true",
)
parser.add_argument(
"--base_model_name_or_path",
help="Base model for adapter",
)
parsed_args = parser.parse_args()

print(f"Multiclass / multioutput delimiter: {parsed_args.delimiter}")
Expand Down Expand Up @@ -446,7 +450,20 @@ def export_experiment_info(

if __name__ == "__main__":
args = parse_and_validate_args()
tuned_model = TunedCausalLM.load(args.model, use_flash_attn=args.use_flash_attn)

base_model_name_or_path = args.base_model_name_or_path
if not base_model_name_or_path:
adapter_config_path = os.path.join(args.model, "adapter_config.json")
if os.path.exists(adapter_config_path):
with open(adapter_config_path, "r", encoding="utf-8") as config_file:
adapter_config = json.load(config_file)
base_model_name_or_path = adapter_config.get("base_model_name_or_path")

tuned_model = TunedCausalLM.load(
args.model,
use_flash_attn=args.use_flash_attn,
base_model_name_or_path=base_model_name_or_path,
)
eval_data = datasets.load_dataset(
"json", data_files=args.data_path, split=args.split
)
Expand Down
1 change: 1 addition & 0 deletions tests/data/trainercontroller/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,4 @@
_DATA_DIR, "thresholded-training-loss.yaml"
)
TRAINER_CONFIG_TEST_ON_SAVE_YAML = os.path.join(_DATA_DIR, "on-save.yaml")
TRAINER_CONFIG_LOG_CONTROLLER_YAML = os.path.join(_DATA_DIR, "log_controller.yaml")
16 changes: 16 additions & 0 deletions tests/data/trainercontroller/log_controller.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
controller_metrics:
- name: trainer_state
class: TrainingState
operations:
- name: logcontrolstep
class: LogControl
arguments:
log_format: 'This is a test log format [{event_name}] => {trainer_state}'
log_level: warning
controllers:
- name: log-controller-step
triggers:
- on_step_end
rule: 'True'
operations:
- logcontrolstep.should_log
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
controller_metrics:
- name: state
- name: trainer_state
class: TrainingState
- name: training_loss
class: Loss
controllers:
- name: loss_controller
triggers:
- on_log
rule: training_loss['loss'] < 2 and state["epoch"] >= 0.5
rule: training_loss['loss'] < 2 and trainer_state["epoch"] >= 0.5
operations:
- hfcontrols.should_training_stop
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
controller_metrics:
- name: loss
- name: training_loss
class: Loss
controllers:
- name: loss_controller_wrong_os_rule
Expand Down
4 changes: 2 additions & 2 deletions tests/data/trainercontroller/on-save.yaml
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
controller_metrics:
- name: state
- name: trainer_state
class: TrainingState
controllers:
- name: stop_on_training_loss_on_save
triggers:
- on_save
rule: state["epoch"] >= 0.5
rule: trainer_state["epoch"] >= 0.5
operations:
- hfcontrols.should_training_stop
17 changes: 17 additions & 0 deletions tests/trainercontroller/test_tuning_trainercontroller.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,23 @@ def test_thresholded_training_loss_on_save():
assert control.should_training_stop is True


def test_log_controller(caplog):
    """Tests that the LogControl operation emits its formatted log message.

    Uses the test copy of the log controller config, which is also shipped as
    the example `examples/trainercontroller_configs/log_controller.yaml`. Its
    controller has the rule `True` and is triggered on `on_step_end`, so
    firing that event is expected to produce a log line built from the
    configured `log_format`.
    """
    test_data = _setup_data()
    tc_callback = tc.TrainerControllerCallback(td.TRAINER_CONFIG_LOG_CONTROLLER_YAML)
    control = TrainerControl(should_log=False)
    # Trigger on_init_end to perform registration of handlers to events
    tc_callback.on_init_end(
        args=test_data.args, state=test_data.states[2], control=control
    )
    # Fire the event the controller is registered for in the config.
    tc_callback.on_step_end(
        args=test_data.args, state=test_data.states[2], control=control
    )
    # Only the static prefix of the log format is checked; the rendered
    # `{event_name}` and `{trainer_state}` parts are not asserted here.
    assert "This is a test log format" in caplog.text


def test_non_decreasing_training_loss():
"""Tests the non-decreasing training loss example in
`examples/trainer-controller-configs/non-decreasing-training-loss.yaml`
Expand Down
6 changes: 4 additions & 2 deletions tuning/trackers/aimstack_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ def on_init_end(self, args, state, control, **kwargs):
Args:
For the arguments see reference to transformers.TrainingCallback
"""
# pylint: disable=unused-argument
self.setup() # initialize aim's run_hash
super().on_init_end(args, state, control, **kwargs)

if not self._run:
return

# Change default run hash path to output directory if not specified
if self.run_id_export_path is None:
Expand Down
6 changes: 3 additions & 3 deletions tuning/trainercontroller/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,12 +255,12 @@ def _take_control_actions(self, event_name: str, **kwargs):
for operation_action in control_action.operation_actions:
operation_action.instance.act(
action=operation_action.action,
event_name=event_name,
tc_metrics=self.metrics,
control_name=control_action.name,
log_level=control_action.config[
CONTROLLER_CONFIG_TRIGGER_LOG_LEVEL
],
event_name=event_name,
control_name=control_action.name,
**self.metrics,
**kwargs,
)

Expand Down
5 changes: 4 additions & 1 deletion tuning/trainercontroller/controllermetrics/trainingstate.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,13 @@

# Third Party
from transformers import TrainerState
from transformers.utils import logging

# Local
from tuning.trainercontroller.controllermetrics.metricshandler import MetricHandler

logger = logging.get_logger(__name__)


class TrainingState(MetricHandler):
"""Implements the controller metric which exposes the trainer state"""
Expand All @@ -49,7 +52,7 @@ def __init__(self, **kwargs):
"on_train_begin",
"on_evaluate",
],
**kwargs
**kwargs,
)

def validate(self) -> bool:
Expand Down
2 changes: 2 additions & 0 deletions tuning/trainercontroller/operations/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

# Local
from .hfcontrols import HFControls
from .logcontrol import LogControl
from .operation import Operation

# List of operation handlers
Expand All @@ -20,3 +21,4 @@ def register(cl: Type):

# Register the default operation handlers in this package here
register(HFControls)
register(LogControl)
55 changes: 55 additions & 0 deletions tuning/trainercontroller/operations/logcontrol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Third Party
from transformers import TrainingArguments
from transformers.utils import logging

# Local
from .operation import Operation

# Module-level logger through which LogControl emits its messages.
logger = logging.get_logger(__name__)
# Lower this logger's own threshold to DEBUG so that records at DEBUG and above
# pass its level check.
# NOTE(review): presumably done so every level a LogControl can be configured
# with is emitted; handlers may still filter by their own level - confirm.
logger.setLevel(level=logging.DEBUG)


class LogControl(Operation):
    """Operation that can be used to log useful information on specific events.

    The message template and the level are fixed at construction time; the
    `should_log` action renders the template with the values it is called
    with and writes the result to this module's logger.
    """

    def __init__(self, log_format: str, log_level: str, **kwargs):
        """Initializes the log control operation.

        Args:
            log_format: str. `str.format`-style template of the message to log.
                It is rendered with `event_name`, `control_name`, `args` and
                any further keyword arguments given to `should_log()`.
            log_level: str. Name of the level to log at (e.g. `warning`). Must
                be a key of `transformers.utils.logging.get_log_levels_dict()`.
            kwargs: List of arguments (key, value)-pairs, forwarded to the
                `Operation` base class.

        Raises:
            ValueError: If `log_level` is not a known log level name.
        """
        log_levels = logging.get_log_levels_dict()
        if log_level not in log_levels:
            raise ValueError(
                f"Specified log_level [{log_level}] is invalid for LogControl"
            )
        # Store the resolved level value rather than its name, so that it can
        # be handed to `logger.log()` directly.
        self.log_level = log_levels[log_level]
        self.log_format = log_format
        super().__init__(**kwargs)

    def should_log(
        self,
        event_name: str = None,
        control_name: str = None,
        args: TrainingArguments = None,
        **kwargs,
    ):
        """Renders the configured log format and logs it at the configured level.

        Args:
            event_name: str. Name of the event that triggered this action.
            control_name: str. Name of the controller that triggered this action.
            args: TrainingArguments. Made available to the log format as `args`.
            kwargs: List of arguments (key, value)-pairs. Each key can be
                referenced by name in the log format (the example config
                references `trainer_state` this way).

        Raises:
            KeyError: Raised by `str.format` if the log format references a
                named field that was not supplied.
        """
        log_msg = self.log_format.format(
            event_name=event_name,
            control_name=control_name,
            args=args,
            **kwargs,
        )
        logger.log(
            self.log_level,
            log_msg,
        )
19 changes: 18 additions & 1 deletion tuning/trainercontroller/operations/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def __init__(self, name: str, **kwargs):
every action name should start with the prefix `should_`. If so, it is treated as a valid
action.
"""
self._name = name
self.kwargs = kwargs
self.valid_actions = {}
self.name = name
self.kwargs = kwargs
Expand All @@ -26,6 +28,14 @@ def __init__(self, name: str, **kwargs):
if re.search(r"^should_.+", action_name) is not None:
self.valid_actions[action_name] = action_method

def get_name(self) -> str:
    """Gets the name this operation was created with.

    Returns:
        str. The name of the operation.
    """
    return self._name

def validate(self, action: str) -> bool:
"""Validates the action by checking whether it is a valid action or not.
Expand All @@ -38,7 +48,12 @@ def validate(self, action: str) -> bool:
return action in self.valid_actions

def act(
self, action: str, event_name: str, control_name: str, log_level: int, **kwargs
self,
action: str,
log_level: int,
event_name: str = None,
control_name: str = None,
**kwargs,
):
"""Validates the action and invokes it.
Expand All @@ -58,6 +73,8 @@ def act(
control_name,
event_name,
)
kwargs["event_name"] = event_name
kwargs["control_name"] = control_name
self.valid_actions[action](**kwargs)

def get_actions(self) -> list[str]:
Expand Down

0 comments on commit f221687

Please sign in to comment.