Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Add WandB Tracker #342

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions tuning/config/tracker_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

# Standard
from dataclasses import dataclass
from typing import List, Optional


@dataclass
Expand Down Expand Up @@ -62,6 +63,11 @@ def __post_init__(self):
+ "/"
)

@dataclass
class WandBConfig:
    """Configuration values for the Weights & Biases (wandb) tracker.

    Attributes are plain dataclass fields; both have defaults, so the config
    can be constructed with no arguments.
    """

    # experiment / project name
    project: str = "fms-hf-tuning"
    # Optional entity name; stays None unless explicitly provided.
    entity: Optional[str] = None


@dataclass
class TrackerConfigFactory:
Expand Down
27 changes: 25 additions & 2 deletions tuning/trackers/tracker_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,18 @@

# Information about all registered trackers
AIMSTACK_TRACKER = "aim"
WANDB_TRACKER = "wandb"
FILE_LOGGING_TRACKER = "file_logger"

# Names of every tracker known to this module. This is the single
# definition; an earlier assignment without WANDB_TRACKER was immediately
# overwritten and has been dropped.
AVAILABLE_TRACKERS = [AIMSTACK_TRACKER, FILE_LOGGING_TRACKER, WANDB_TRACKER]


# Trackers which can be used
REGISTERED_TRACKERS = {}

# One time package check for list of external trackers.
_is_aim_available = _is_package_available("aim")
_is_wandb_available = _is_package_available("wandb")

def _get_tracker_class(T, C):
    """Bundle a tracker class with its config class.

    Args:
        T: The tracker class.
        C: The config class paired with that tracker.

    Returns:
        A dict with the tracker under ``"tracker"`` and the config under
        ``"config"``.
    """
    entry = dict(tracker=T, config=C)
    return entry
Expand All @@ -59,10 +60,30 @@ def _register_aim_tracker():
"\t pip install aim"
)

def _register_wandb_tracker():
    """Register the wandb tracker when the ``wandb`` package is available.

    When the one-time availability check is false, nothing is registered and
    an informational message with install instructions is logged instead.
    """
    # pylint: disable=import-outside-toplevel
    if not _is_wandb_available:
        logger.info(
            "Not registering WANDB due to unavailablity of package.\n"
            "Please install wandb if you intend to use it.\n"
            "\t pip install wandb"
        )
        return

    # Local
    # Imported lazily so that this module can be loaded without wandb installed.
    from .wandb_tracker import WandBTracker
    from tuning.config.tracker_configs import WandBConfig

    wandb_entry = _get_tracker_class(WandBTracker, WandBConfig)
    REGISTERED_TRACKERS[WANDB_TRACKER] = wandb_entry
    logger.info("Registered wandb tracker")


def _is_tracker_installed(name):
    """Report whether the external package behind tracker ``name`` is available.

    Args:
        name: Tracker name; only ``"aim"`` and ``"wandb"`` are recognised.

    Returns:
        The cached availability flag for a recognised name, otherwise False.
    """
    known_packages = (
        ("aim", _is_aim_available),
        ("wandb", _is_wandb_available),
    )
    for tracker_name, installed in known_packages:
        if name == tracker_name:
            return installed
    return False


Expand All @@ -79,6 +100,8 @@ def _register_trackers():
logging.info("Registering trackers")
if AIMSTACK_TRACKER not in REGISTERED_TRACKERS:
_register_aim_tracker()
if WANDB_TRACKER not in REGISTERED_TRACKERS:
_register_wandb_tracker()
if FILE_LOGGING_TRACKER not in REGISTERED_TRACKERS:
_register_file_logging_tracker()

Expand Down
64 changes: 64 additions & 0 deletions tuning/trackers/wandb_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Copyright The FMS HF Tuning Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
import json
import os

# Third Party
import wandb
from transformers.integrations import WandbCallback
from transformers.utils import logging

# Local
from .tracker import Tracker
from tuning.config.tracker_configs import WandBConfig

class WandBTracker(Tracker):
    """Tracker that sends metrics and run parameters to Weights & Biases."""

    def __init__(self, tracker_config: WandBConfig):
        """Tracker which uses Wandb to collect and store metrics.

        Args:
            tracker_config: Config whose ``project`` and ``entity`` are passed
                to ``wandb.init`` when the run is started.
        """
        # The tracker must identify itself as "wandb"; the previous "aim"
        # value was a copy/paste leftover from the Aimstack tracker.
        super().__init__(name="wandb", tracker_config=tracker_config)
        self.logger = logging.get_logger("wandb_tracker")
        # Both are created on first use so that track()/set_params() work
        # regardless of whether get_hf_callback() was called beforehand.
        self.run = None
        self.hf_callback = None

    def _ensure_run(self):
        """Return the wandb run for this tracker, starting it on first use."""
        # getattr keeps this safe even on instances created without __init__.
        if getattr(self, "run", None) is None:
            # Relies on the base Tracker exposing the config as ``self.config``,
            # exactly as the original get_hf_callback implementation did.
            cfg = self.config
            self.run = wandb.init(project=cfg.project, entity=cfg.entity)
        return self.run

    def get_hf_callback(self):
        """Returns the WandBCallback object associated with this tracker.

        The wandb run is started if it is not already, and a single
        ``WandbCallback`` instance is created and reused across calls.
        """
        self._ensure_run()
        if getattr(self, "hf_callback", None) is None:
            # Bind the instance to a differently named attribute: assigning to
            # a local called ``WandbCallback`` shadows the imported class and
            # raises UnboundLocalError before the constructor can be called.
            self.hf_callback = WandbCallback()
        return self.hf_callback

    def _wandb_log(self, data, name):
        """Log ``data`` to the wandb run under the key ``name``."""
        self._ensure_run().log({name: data})

    def track(self, metric, name, stage):
        """Track any additional metric with name under the wandb tracker.

        Args:
            metric: Value to log; must not be None.
            name: Key under which the value is logged; must not be None.
            stage: Accepted for interface compatibility; not used by this
                tracker.

        Raises:
            ValueError: If ``metric`` or ``name`` is None.
        """
        if metric is None or name is None:
            raise ValueError(
                "wandb track function should not be called with None metric value or name"
            )
        self._wandb_log(metric, name)

    def set_params(self, params, name="extra_params"):
        """Attach any extra params with the run information stored in wandb.

        Parameters are stored in the run config under ``name`` rather than
        logged as a metrics step, so they are recorded once as run metadata.

        Args:
            params: Mapping of parameters to attach; None is ignored.
            name: Config key under which ``params`` is grouped.
        """
        if params is None:
            return
        # allow_val_change lets repeated calls with the same name overwrite
        # earlier values instead of raising.
        self._ensure_run().config.update({name: params}, allow_val_change=True)
Loading