From be8ebcfe7d2d7d9c6abd408c6388955568ff70a3 Mon Sep 17 00:00:00 2001
From: Hanlin Tang <hanlin@mosaicml.com>
Date: Tue, 3 May 2022 16:18:55 -0700
Subject: [PATCH] Add MLPerf logging (#831)

Adds an experimental logger that creates MLPerf-compliant submission files.
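
A minimal sketch of attaching the callback to the Trainer (the toy model and
dataset below are the helpers used by the unit tests in this patch, and the
submission path is arbitrary):

    from torch.utils.data import DataLoader

    from composer import Trainer
    from composer.callbacks import MLPerfCallback
    from tests.common import RandomClassificationDataset, SimpleModel

    trainer = Trainer(
        model=SimpleModel(),
        train_dataloader=DataLoader(RandomClassificationDataset(), batch_size=4),
        eval_dataloader=DataLoader(RandomClassificationDataset()),
        max_duration="3ep",
        callbacks=[MLPerfCallback(root_folder="/tmp/submission", index=0)],
    )
    trainer.fit()

The extra dependencies install via `pip install mosaicml[mlperf]`; the
mlperf_logging package itself must currently be installed from
https://github.com/mlcommons/logging.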
---
 Makefile                                |   2 +-
 composer/callbacks/__init__.py          |   6 +-
 composer/callbacks/callback_hparams.py  |  62 +++-
 composer/callbacks/mlperf.py            | 378 ++++++++++++++++++++++++
 composer/core/state.py                  |  24 +-
 composer/trainer/trainer.py             |   5 +-
 composer/trainer/trainer_hparams.py     |   3 +-
 setup.py                                |   6 +
 tests/callbacks/test_callbacks.py       |   2 +
 tests/callbacks/test_mlperf_callback.py | 171 +++++++++++
 tests/trainer/test_trainer.py           |  11 +-
 11 files changed, 657 insertions(+), 13 deletions(-)
 create mode 100644 composer/callbacks/mlperf.py
 create mode 100644 tests/callbacks/test_mlperf_callback.py

diff --git a/Makefile b/Makefile
index fef48ca531..ad1be9b8e4 100644
--- a/Makefile
+++ b/Makefile
@@ -27,4 +27,4 @@ test-dist-gpu:
 clean-notebooks:
 	$(PYTHON) scripts/clean_notebooks.py -i notebooks/*.ipynb
 
-.PHONY: test test-gpu test-dist test-dist-gpu lint style clean-notebooks
+.PHONY: test test-gpu test-dist test-dist-gpu clean-notebooks
diff --git a/composer/callbacks/__init__.py b/composer/callbacks/__init__.py
index e6508a8f89..eaf2bee41f 100644
--- a/composer/callbacks/__init__.py
+++ b/composer/callbacks/__init__.py
@@ -6,11 +6,13 @@
 examples for writing your own callbacks at the :class:`~composer.core.callback.Callback` base class.
 """
 from composer.callbacks.callback_hparams import (CallbackHparams, CheckpointSaverHparams, GradMonitorHparams,
-                                                 LRMonitorHparams, MemoryMonitorHparams, SpeedMonitorHparams)
+                                                 LRMonitorHparams, MemoryMonitorHparams, MLPerfCallbackHparams,
+                                                 SpeedMonitorHparams)
 from composer.callbacks.checkpoint_saver import CheckpointSaver
 from composer.callbacks.grad_monitor import GradMonitor
 from composer.callbacks.lr_monitor import LRMonitor
 from composer.callbacks.memory_monitor import MemoryMonitor
+from composer.callbacks.mlperf import MLPerfCallback
 from composer.callbacks.speed_monitor import SpeedMonitor
 
 __all__ = [
@@ -19,6 +21,7 @@
     "MemoryMonitor",
     "SpeedMonitor",
     "CheckpointSaver",
+    "MLPerfCallback",
     # hparams objects
     "CallbackHparams",
     "CheckpointSaverHparams",
@@ -26,4 +29,5 @@
     "LRMonitorHparams",
     "MemoryMonitorHparams",
     "SpeedMonitorHparams",
+    "MLPerfCallbackHparams",
 ]
diff --git a/composer/callbacks/callback_hparams.py b/composer/callbacks/callback_hparams.py
index 970633a4c3..5742fe6dff 100644
--- a/composer/callbacks/callback_hparams.py
+++ b/composer/callbacks/callback_hparams.py
@@ -5,7 +5,7 @@
 
 import abc
 import textwrap
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass
 from typing import Optional
 
 import yahp as hp
@@ -14,6 +14,7 @@
 from composer.callbacks.grad_monitor import GradMonitor
 from composer.callbacks.lr_monitor import LRMonitor
 from composer.callbacks.memory_monitor import MemoryMonitor
+from composer.callbacks.mlperf import MLPerfCallback
 from composer.callbacks.speed_monitor import SpeedMonitor
 from composer.core.callback import Callback
 from composer.core.time import Time
@@ -48,7 +49,7 @@ class GradMonitorHparams(CallbackHparams):
     """:class:`~.GradMonitor` hyperparamters.
 
     Args:
-        log_layer_grad_norms (bool, optional): 
+        log_layer_grad_norms (bool, optional):
             See :class:`~.GradMonitor` for documentation.
     """
 
@@ -119,10 +120,65 @@ def initialize_object(self) -> SpeedMonitor:
         return SpeedMonitor(window_size=self.window_size)
 
 
+@dataclass
+class MLPerfCallbackHparams(CallbackHparams):
+    """:class:`~.MLPerfCallback` hyperparameters.
+
+    Args:
+        root_folder (str): The root submission folder.
+        index (int): The repetition index of this run. The filename created will be
+            ``result_[index].txt``.
+        benchmark (str, optional): Benchmark name. Currently only ``resnet`` is supported.
+        target (float, optional): The target metric; once reached, the mllogger marks the stop
+            of the timing run. Default: ``0.759`` (resnet benchmark).
+        division (str, optional): Division of the submission. Currently only the ``open`` division is supported.
+        metric_name (str, optional): Name of the metric to compare against the target. Default: ``Accuracy``.
+        metric_label (str, optional): Label name. The metric will be accessed via
+            ``state.current_metrics[metric_label][metric_name]``. Default: ``eval``.
+        submitter (str, optional): Submitting organization. Default: ``MosaicML``.
+        system_name (str, optional): Name of the system (e.g. ``8xA100_composer``). If
+            not provided, the system name will default to ``[world_size]x[device_name]_composer``,
+            e.g. ``8xNVIDIA_A100_80GB_composer``.
+        status (str, optional): Submission status. One of ``onprem``, ``cloud``, or ``preview``.
+            Default: ``"onprem"``.
+        cache_clear_cmd (str, optional): Command to invoke to clear the system caches. This callback
+            will call ``subprocess.run(cache_clear_cmd)``. Default: ``None`` (disabled).
+
+    """
+
+    root_folder: str = hp.required("The root submission folder.")
+    index: int = hp.required("The repetition index of this run.")
+    benchmark: str = hp.optional("Benchmark name. Default: resnet", default="resnet")
+    target: float = hp.optional("The target metric before mllogger marks run_stop. Default: 0.759 (resnet)",
+                                default=0.759)
+    division: Optional[str] = hp.optional(
+        "Division of submission. Currently only open division"
+        "is supported. Default: open", default="open")
+    metric_name: str = hp.optional('name of the metric to compare against the target. Default: Accuracy',
+                                   default='Accuracy')
+    metric_label: str = hp.optional(
+        'label name. The metric will be accessed via state.current_metrics[metric_label][metric_name]. Default: eval',
+        default='eval')
+    submitter: str = hp.optional("Submitting organization. Default: MosaicML", default='MosaicML')
+    system_name: Optional[str] = hp.optional("Name of the system, defaults to [world_size]x[device_name]", default=None)
+    status: str = hp.optional("Submission status. Default: onprem", default="onprem")
+    cache_clear_cmd: Optional[str] = hp.optional(
+        "Command to invoke during the cache clear. This callback will call subprocess(cache_clear_cmd). Default: Disabled.",
+        default=None,
+    )
+
+    def initialize_object(self) -> MLPerfCallback:
+        """Initialize the MLPerf Callback.
+
+        Returns:
+            MLPerfCallback: An instance of :class:`~.MLPerfCallback`
+        """
+        return MLPerfCallback(**asdict(self))
+
+
 @dataclass
 class CheckpointSaverHparams(CallbackHparams):
     """:class:`~.CheckpointSaver` hyperparameters.
-    
+
     Args:
         save_folder (str, optional): See :class:`~.CheckpointSaver`.
         filename (str, optional): See :class:`~.CheckpointSaver`.
diff --git a/composer/callbacks/mlperf.py b/composer/callbacks/mlperf.py
new file mode 100644
index 0000000000..590eb285a6
--- /dev/null
+++ b/composer/callbacks/mlperf.py
@@ -0,0 +1,378 @@
+import json
+import logging
+import os
+import platform
+import subprocess
+import sys
+import warnings
+from typing import Any, Dict, List, Optional, Sized
+
+import torch
+from torch.utils.data import DataLoader
+
+import composer
+from composer.core import State
+from composer.core.callback import Callback
+from composer.loggers import Logger
+from composer.loggers.logger import LogLevel
+from composer.utils import dist
+
+try:
+    import cpuinfo
+    import psutil
+    from mlperf_logging import mllog  # type: ignore (no pypi for ci)
+    from mlperf_logging.mllog import constants  # type: ignore (no pypi for ci)
+
+    mlperf_available = True
+except ImportError:
+    mlperf_available = False
+
+# this callback only supports the following options:
+BENCHMARKS = ("resnet",)
+DIVISIONS = ("open",)
+STATUS = ("onprem", "cloud", "preview")
+
+
+def rank_zero() -> bool:
+    return dist.get_global_rank() == 0
+
+
+def require_mlperf_logging():
+    if not mlperf_available:
+        raise ImportError("""Please install with `pip install mosaicml[mlperf]` and also
+                          install the logging library from: https://github.com/mlcommons/logging""")
+
+
+class MLPerfCallback(Callback):
+    """Creates a compliant results file for MLPerf Training benchmark.
+
+    A submission folder structure will be created with the ``root_folder``
+    as the base and the following directories::
+
+        root_folder/
+            results/
+                [system_name]/
+                    [benchmark]/
+                        results_0.txt
+                        results_1.txt
+                        ...
+            systems/
+                [system_name].json
+
+    A required system description file will be generated automatically, with a
+    best effort made to populate the fields, but it should be checked manually
+    prior to submission.
+
+    Currently, only open division submissions are supported with this Callback.
+
+    Example:
+
+    .. code-block:: python
+
+        from composer.callbacks import MLPerfCallback
+
+        callback = MLPerfCallback(
+            root_folder='/submission',
+            index=0,
+            metric_name='Accuracy',
+            metric_label='eval',
+            target=0.759,
+        )
+
+    During training, the metric found in ``state.current_metrics[metric_label][metric_name]``
+    will be compared against the target criterion.
+
+    .. note::
+
+        This is currently an experimental logger that has not yet been used
+        to submit an actual result to MLPerf. Please use with caution.
+
+    .. note::
+
+        MLPerf submissions require clearing the system cache prior to any training run.
+        By default, this callback does not clear the cache, as that is a system-specific
+        operation. To enable cache clearing, and thus pass the MLPerf compliance checker,
+        provide a ``cache_clear_cmd`` that will be executed with ``subprocess.run``.
+
+    Args:
+        root_folder (str): The root submission folder.
+        index (int): The repetition index of this run. The filename created will be
+            ``result_[index].txt``.
+        benchmark (str, optional): Benchmark name. Currently only ``resnet`` is supported.
+        target (float, optional): The target metric; once reached, the mllogger marks the stop
+            of the timing run. Default: ``0.759`` (resnet benchmark).
+        division (str, optional): Division of the submission. Currently only the ``open`` division is supported.
+        metric_name (str, optional): Name of the metric to compare against the target. Default: ``Accuracy``.
+        metric_label (str, optional): Label name. The metric will be accessed via
+            ``state.current_metrics[metric_label][metric_name]``. Default: ``eval``.
+        submitter (str, optional): Submitting organization. Default: ``MosaicML``.
+        system_name (str, optional): Name of the system (e.g. ``8xA100_composer``). If
+            not provided, the system name will default to ``[world_size]x[device_name]_composer``,
+            e.g. ``8xNVIDIA_A100_80GB_composer``.
+        status (str, optional): Submission status. One of ``onprem``, ``cloud``, or ``preview``.
+            Default: ``"onprem"``.
+        cache_clear_cmd (List[str], optional): Command to invoke to clear the system caches. This callback
+            will call ``subprocess.run(cache_clear_cmd)``. Default: ``None`` (disabled).
+    """
+
+    def __init__(
+        self,
+        root_folder: str,
+        index: int,
+        benchmark: str = 'resnet',
+        target: float = 0.759,
+        division: str = 'open',
+        metric_name: str = 'Accuracy',
+        metric_label: str = 'eval',
+        submitter: str = "MosaicML",
+        system_name: Optional[str] = None,
+        status: str = "onprem",
+        cache_clear_cmd: Optional[List[str]] = None,
+    ) -> None:
+
+        require_mlperf_logging()
+
+        if benchmark not in BENCHMARKS:
+            raise ValueError(f"benchmark: {benchmark} must be one of {BENCHMARKS}")
+        if division not in DIVISIONS:
+            raise ValueError(f"division: {division} must be one of {DIVISIONS}")
+        if status not in STATUS:
+            raise ValueError(f"status: {status} must be one of {STATUS}")
+
+        self.mllogger = mllog.get_mllogger()
+        self.target = target
+        self.benchmark = benchmark
+        self.division = division
+        self.submitter = submitter
+        self.status = status
+        self.cache_clear_cmd = cache_clear_cmd
+        self.root_folder = root_folder
+        self.metric_name = metric_name
+        self.metric_label = metric_label
+        self._file_handler = None
+
+        self.system_desc = get_system_description(submitter, division, status, system_name)
+        if system_name is None:
+            system_name = self.system_desc['system_name']
+        self.system_name = system_name
+
+        # file paths to save the systems file, results file
+        self.systems_path = os.path.join(root_folder, 'systems', f'{system_name}.json')
+        self.filename = os.path.join(root_folder, 'results', system_name, benchmark, f'result_{index}.txt')
+
+        # upload names for object store logging
+        self.upload_name = '{run_name}' + f'/results/{system_name}/{benchmark}/result_{index}.txt'
+        self.system_desc_upload_name = '{run_name}' + f'/systems/{system_name}.json'
+
+        self.success = False
+
+    def init(self, state: State, logger: Logger) -> None:
+
+        # setup here requires access to rank, which is only available after
+        # the trainer is initialized
+        if dist.get_local_rank() == 0:
+            self._create_submission_folders(self.root_folder, self.system_name, self.benchmark)
+            with open(self.systems_path, 'w') as f:
+                json.dump(self.system_desc, f, indent=4)
+
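+        # wait for local rank 0 to finish creating the submission folders before any rank opens the results file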
+        dist.barrier()
+
+        if os.path.exists(self.filename):
+            raise FileExistsError(f'{self.filename} already exists.')
+
+        self._file_handler = logging.FileHandler(self.filename)
+        self._file_handler.setLevel(logging.INFO)
+        self.mllogger.logger.addHandler(self._file_handler)
+
+        if self.cache_clear_cmd is not None:
+            subprocess.run(self.cache_clear_cmd, check=True, text=True)
+            self.mllogger.start(key=mllog.constants.CACHE_CLEAR)
+        else:
+            warnings.warn("cache_clear_cmd was not provided. For a valid submission, please provide the command.")
+
+        self.mllogger.start(key=mllog.constants.INIT_START)
+
+        if rank_zero():
+            self._log_dict({
+                constants.SUBMISSION_BENCHMARK: self.benchmark,
+                constants.SUBMISSION_DIVISION: self.division,
+                constants.SUBMISSION_ORG: self.submitter,
+                constants.SUBMISSION_PLATFORM: self.system_name,
+                constants.SUBMISSION_STATUS: self.status,
+            })
+
+            # optionally, upload the system description file
+            logger.file_artifact(LogLevel.FIT, self.system_desc_upload_name, self.systems_path)
+
+    def _create_submission_folders(self, root_folder: str, system_name: str, benchmark: str):
+        os.makedirs(root_folder, exist_ok=True)
+
+        results_folder = os.path.join(root_folder, 'results')
+        log_folder = os.path.join(root_folder, 'results', system_name)
+        benchmark_folder = os.path.join(log_folder, benchmark)
+        systems_folder = os.path.join(root_folder, 'systems')
+
+        os.makedirs(results_folder, exist_ok=True)
+        os.makedirs(log_folder, exist_ok=True)
+        os.makedirs(benchmark_folder, exist_ok=True)
+        os.makedirs(systems_folder, exist_ok=True)
+
+    def _log_dict(self, data: Dict[str, Any]):
+        for key, value in data.items():
+            self.mllogger.event(key=key, value=value)
+
+    def _get_accuracy(self, state: State):
+        if self.metric_name not in state.current_metrics[self.metric_label]:
+            raise ValueError(f'{self.metric_name} must be a validation metric.')
+        return state.current_metrics[self.metric_label][self.metric_name]
+
+    def fit_start(self, state: State, logger: Logger) -> None:
+        if rank_zero():
+
+            if not isinstance(state.train_dataloader, DataLoader):
+                raise TypeError("train dataloader must be a torch dataloader")
+            if not isinstance(state.evaluators[0].dataloader.dataloader, DataLoader):
+                raise TypeError("eval dataset must be a torch dataloader.")
+            if state.train_dataloader.batch_size is None:
+                raise ValueError("Batch size is required to be set for dataloader.")
+            if len(state.evaluators) > 1:
+                raise ValueError("Only one evaluator is supported for the MLPerfCallback.")
+            if not isinstance(state.train_dataloader.dataset, Sized):
+                raise TypeError("Train dataset must have __len__ property")
+            if not isinstance(state.evaluators[0].dataloader.dataloader.dataset, Sized):
+                raise TypeError("The eval dataset must have __len__ property")
+
+            self._log_dict({
+                constants.SEED: state.seed,
+                constants.GLOBAL_BATCH_SIZE: state.train_dataloader.batch_size * dist.get_world_size(),
+                constants.GRADIENT_ACCUMULATION_STEPS: state.grad_accum,
+                constants.TRAIN_SAMPLES: len(state.train_dataloader.dataset),
+                constants.EVAL_SAMPLES: len(state.evaluators[0].dataloader.dataloader.dataset)
+            })
+
+        self.mllogger.event(key=constants.INIT_STOP)
+
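+        # synchronize so that RUN_START is logged only after every rank has finished initialization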
+        dist.barrier()
+
+        if rank_zero():
+            self.mllogger.event(key=constants.RUN_START)
+
+    def epoch_start(self, state: State, logger: Logger) -> None:
+        if rank_zero():
+            self.mllogger.event(key=constants.EPOCH_START, metadata={'epoch_num': state.timer.epoch.value})
+            self.mllogger.event(key=constants.BLOCK_START,
+                                metadata={
+                                    'first_epoch_num': state.timer.epoch.value,
+                                    'epoch_count': 1
+                                })
+
+    def epoch_end(self, state: State, logger: Logger) -> None:
+        if rank_zero():
+            self.mllogger.event(key=constants.EPOCH_STOP, metadata={'epoch_num': state.timer.epoch.value})
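+            # upload the in-progress results file after each epoch (e.g. to an object store logger)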
+            logger.file_artifact(LogLevel.FIT, artifact_name=self.upload_name, file_path=self.filename)
+
+    def eval_start(self, state: State, logger: Logger) -> None:
+        if rank_zero():
+            self.mllogger.event(key=constants.EVAL_START, metadata={'epoch_num': state.timer.epoch.value})
+
+    def eval_end(self, state: State, logger: Logger) -> None:
+        if rank_zero():
+            accuracy = self._get_accuracy(state)
+
+            self.mllogger.event(key=constants.EVAL_STOP, metadata={'epoch_num': state.timer.epoch.value})
+            self.mllogger.event(key=constants.EVAL_ACCURACY,
+                                value=accuracy,
+                                metadata={'epoch_num': state.timer.epoch.value})
+            self.mllogger.event(key=constants.BLOCK_STOP, metadata={'first_epoch_num': state.timer.epoch.value})
+
+            if accuracy > self.target and not self.success:
+                self.mllogger.event(key=constants.RUN_STOP, metadata={"status": "success"})
+                self.mllogger.logger.removeHandler(self._file_handler)
+                self.success = True  # only log once
+
+    def close(self, state: State, logger: Logger) -> None:
+        if self._file_handler is not None:
+            self._file_handler.close()
+
+
+def get_system_description(
+    submitter: str,
+    division: str,
+    status: str,
+    system_name: Optional[str] = None,
+) -> Dict[str, str]:
+    """Generates a valid system description.
+
+    Make a best effort to auto-populate some of the fields, but should
+    be manually checked prior to submission. The system name is
+    auto-generated as "[world_size]x[device_name]_composer", e.g.
+    "8xNVIDIA_A100_80GB_composer".
+
+    Args:
+        submitter (str): Name of the submitter organization
+        division (str): Submission division (open, closed)
+        status (str): system status (cloud, onprem, preview)
+
+    Returns:
+        system description as a dictionary
+    """
+    is_cuda = torch.cuda.is_available()
+    cpu_info = cpuinfo.get_cpu_info()
+
+    system_desc = {
+        "submitter": submitter,
+        "division": division,
+        "status": status,
+        "number_of_nodes": dist.get_world_size() / dist.get_local_world_size(),
+        "host_processors_per_node": "",
+        "host_processor_model_name": str(cpu_info.get('brand_raw', "CPU")),
+        "host_processor_core_count": str(psutil.cpu_count(logical=False)),
+        "host_processor_vcpu_count": "",
+        "host_processor_frequency": "",
+        "host_processor_caches": "",
+        "host_processor_interconnect": "",
+        "host_memory_capacity": "",
+        "host_storage_type": "",
+        "host_storage_capacity": "",
+        "host_networking": "",
+        "host_networking_topology": "",
+        "host_memory_configuration": "",
+        "accelerators_per_node": str(dist.get_local_world_size()) if is_cuda else "0",
+        "accelerator_model_name": str(torch.cuda.get_device_name(None)) if is_cuda else "",
+        "accelerator_host_interconnect": "",
+        "accelerator_frequency": "",
+        "accelerator_on-chip_memories": "",
+        "accelerator_memory_configuration": "",
+        "accelerator_memory_capacity": "",
+        "accelerator_interconnect": "",
+        "accelerator_interconnect_topology": "",
+        "cooling": "",
+        "hw_notes": "",
+        "framework":
+            f"PyTorch v{torch.__version__} and MosaicML composer v{composer.__version__}",  # type: ignore (third-party missing stub)
+        "other_software_stack": {
+            "cuda_version": torch.version.cuda if is_cuda else "",  # type: ignore (third-party missing stub)
+            "composer_version": composer.__version__,
+            "python_version": sys.version,
+        },
+        "operating_system": f"{platform.system()} {platform.release()}",
+        "sw_notes": "",
+    }
+
+    if system_desc['number_of_nodes'] != 1:
+        warnings.warn("Number of nodes > 1 not tested, proceed with caution.")
+
+    if system_name is None:
+        world_size = dist.get_world_size()
+        if is_cuda:
+            device_name = system_desc['accelerator_model_name']
+        else:
+            device_name = system_desc['host_processor_model_name']
+
+        device_name = device_name.replace(' ', '_')
+        system_name = f"{world_size}x{device_name}_composer"
+
+    # default to system name as "[world_size]x[device_name]"
+    # e.g. 8xNVIDIA_A100_80GB
+    system_desc['system_name'] = system_name
+    return system_desc
diff --git a/composer/core/state.py b/composer/core/state.py
index a6784b40bf..99d054234d 100644
--- a/composer/core/state.py
+++ b/composer/core/state.py
@@ -24,6 +24,7 @@
     import composer.core.types as types
     from composer.core.algorithm import Algorithm
     from composer.core.callback import Callback
+    from composer.core.evaluator import Evaluator
     from composer.profiler import Profiler
 
 __all__ = ["State"]
@@ -77,7 +78,9 @@ class State(Serializable):
             ``rank_zero_seed + dist.get_global_rank()``.
         grad_accum (int, optional): The number of gradient accumulation steps to use. With this argument, micro batch size for
             each device becomes ``microbatch_size = train_batch_size / (num_devices * grad_accum)``.
-        dataloader (Iterable, optional): The active DataLoader.
+        train_dataloader (types.DataLoader, optional): The dataloader used for training.
+        evaluators (Evaluator | Sequence[Evaluator], optional): The :class:`.Evaluator` objects used for evaluation.
+        dataloader (types.DataLoader, optional): The active DataLoader.
         dataloader_len (int | Time[int], optional): The number of batches per dataloader iteration (e.g. epoch).
             The trainer will yield the first ``dataloader_len`` batches per iteration. If ``-1`` (the default),
             the entire dataloader will be iterated over.
@@ -173,6 +176,7 @@ class State(Serializable):
     _dataloader_label: Optional[str]
     _dataloader_len: Optional[Time[int]]
     _max_duration: Optional[Time[int]]
+
     batch: types.Batch
     batch_num_samples: int
     batch_num_tokens: int
@@ -193,6 +197,13 @@ def __init__(
 
         # data configurations
         grad_accum: int = 1,
+
+        # dataloaders
+        train_dataloader: Optional[Iterable] = None,
+        evaluators: Optional[Union[Evaluator, Sequence[Evaluator]]] = None,
+
+        # these track the current 'active' dataloader
+        # depending on train, eval, or others
         dataloader: Optional[Iterable] = None,
         dataloader_label: Optional[str] = None,
         dataloader_len: Union[int, Time[int]] = -1,
@@ -217,6 +228,9 @@ def __init__(
         self.set_dataloader(dataloader, dataloader_label, dataloader_len)
         self.max_duration = max_duration
 
+        self.train_dataloader = train_dataloader
+        self._evaluators = list(ensure_tuple(evaluators))
+
         self.timer = Timer()
         self._precision = Precision(precision)
 
@@ -317,6 +331,14 @@ def algorithms(self):
     def algorithms(self, algorithms: Sequence[Algorithm]):
         self._algorithms[:] = algorithms
 
+    @property
+    def evaluators(self):
+        return self._evaluators
+
+    @evaluators.setter
+    def evaluators(self, evaluators: Union[Evaluator, Sequence[Evaluator]]):
+        self._evaluators[:] = list(ensure_tuple(evaluators))
+
     def state_dict(self) -> Dict[str, Any]:
         """Returns the state as a :class:`dict`."""
         state_dict = {}
diff --git a/composer/trainer/trainer.py b/composer/trainer/trainer.py
index 342468c511..02210b2ed8 100644
--- a/composer/trainer/trainer.py
+++ b/composer/trainer/trainer.py
@@ -892,6 +892,7 @@ def __init__(
         # After running Event.INIT, then set the "optional" elements of state that could be passed in on FIT instead of INIT
         # Setting these attributes here ensures that algorithms do not depend on unavailable attributes during Event.INIT
         self.state.set_dataloader(train_dataloader.dataloader, 'train', train_subset_num_batches)
+        self.state.train_dataloader = train_dataloader.dataloader
         self.state.max_duration = max_duration
         self.logger.data_fit({"rank_zero_seed": rank_zero_seed})
 
@@ -924,6 +925,8 @@ def __init__(
             if eval_interval != 1:
                 warnings.warn("Specifying `eval_interval` without an `eval_dataloader` has no effect.")
 
+        self.state.evaluators = self.evaluators
+
         # place the state, model in the proper devices, and initialize from a checkpoint if provided
         if self.deepspeed_enabled:
             try:
@@ -1292,7 +1295,7 @@ def _train_batch(self, use_grad_scaling: bool):
             except RuntimeError as e:
                 if self._is_cuda_oom(e):
                     log.debug(
-                        textwrap.dedent(f"""Rank {dist.get_global_rank()} OOM'd. 
+                        textwrap.dedent(f"""Rank {dist.get_global_rank()} OOM'd.
                         grad_accum will be increased prior to reattempting training on the current batch."""))
                     should_handle_cuda_oom = 1
                 elif "Timed out" in str(e):
diff --git a/composer/trainer/trainer_hparams.py b/composer/trainer/trainer_hparams.py
index b5f190aff6..ac710a2995 100755
--- a/composer/trainer/trainer_hparams.py
+++ b/composer/trainer/trainer_hparams.py
@@ -17,7 +17,7 @@
 import composer
 from composer.algorithms import AlgorithmHparams, get_algorithm_registry
 from composer.callbacks import (CallbackHparams, GradMonitorHparams, LRMonitorHparams, MemoryMonitorHparams,
-                                SpeedMonitorHparams)
+                                MLPerfCallbackHparams, SpeedMonitorHparams)
 from composer.core import Precision
 from composer.core.types import JSON
 from composer.datasets import DataLoaderHparams, DatasetHparams
@@ -95,6 +95,7 @@
     "lr_monitor": LRMonitorHparams,
     "grad_monitor": GradMonitorHparams,
     "memory_monitor": MemoryMonitorHparams,
+    "mlperf": MLPerfCallbackHparams,
 }
 
 device_registry = {
diff --git a/setup.py b/setup.py
index 38f51511da..ce0914e7f8 100644
--- a/setup.py
+++ b/setup.py
@@ -151,6 +151,12 @@ def package_files(prefix: str, directory: str, extension: str):
     "wurlitzer>=3.0.2,<4",
 ]
 
+extra_deps["mlperf"] = [
+    # TODO: use pip when available: https://github.com/mlcommons/logging/issues/218
+    # "mlperf_logging @ git+https://github.com/mlperf/logging.git",
+    "py-cpuinfo>=8.0.0,<9",
+]
+
 extra_deps["all"] = set(dep for deps in extra_deps.values() for dep in deps)
 
 composer_data_files = ["py.typed"]
diff --git a/tests/callbacks/test_callbacks.py b/tests/callbacks/test_callbacks.py
index 9eba45b1c2..579f54c412 100644
--- a/tests/callbacks/test_callbacks.py
+++ b/tests/callbacks/test_callbacks.py
@@ -8,6 +8,7 @@
 import composer.callbacks
 import composer.loggers
 import composer.profiler
+from composer.callbacks.mlperf import MLPerfCallback
 from composer.core import Event
 from composer.core.callback import Callback
 from composer.core.engine import Engine
@@ -58,6 +59,7 @@ def _get_callback_factories() -> List[Callable[..., Callback]]:
     callback_factories.extend(
         x for x in vars(composer.profiler).values() if isinstance(x, type) and issubclass(x, Callback))
     callback_factories.remove(ObjectStoreLogger)
+    callback_factories.remove(MLPerfCallback)
     callback_factories.append(lambda: ObjectStoreLogger(
         use_procs=False,
         num_concurrent_uploads=1,
diff --git a/tests/callbacks/test_mlperf_callback.py b/tests/callbacks/test_mlperf_callback.py
new file mode 100644
index 0000000000..5042e9a381
--- /dev/null
+++ b/tests/callbacks/test_mlperf_callback.py
@@ -0,0 +1,171 @@
+# ignore third-party missing imports because the mlperf logger is not pip-installable
+# pyright: reportMissingImports=none
+
+import logging
+from unittest.mock import Mock
+
+import numpy as np
+import pytest
+from torch.utils.data import DataLoader
+
+from composer import State, Trainer
+from composer.callbacks import MLPerfCallback
+from composer.utils import dist
+from tests.common import RandomClassificationDataset, SimpleModel
+
+
+def rank_zero() -> bool:
+    return dist.get_global_rank() == 0
+
+
+@pytest.fixture(autouse=True)
+def import_or_skip_mlperf_logging():
+    pytest.importorskip("mlperf_logging")
+
+
+class MockMLLogger:
+    """Mocks the MLPerf Logger interface."""
+
+    def __init__(self) -> None:
+        self.logs = []
+        self.logger = Mock()
+
+    def event(self, key, metadata, value=None):
+        self.logs.append({'key': key, 'value': value, 'metadata': metadata})
+
+
+class TestMLPerfCallbackEvents:
+
+    @pytest.fixture
+    def mlperf_callback(self, monkeypatch, tmpdir) -> MLPerfCallback:
+        """Returns a callback with the MockMLLogger patched."""
+        callback = MLPerfCallback(tmpdir, 0)
+        monkeypatch.setattr(callback, 'mllogger', MockMLLogger())
+        return callback
+
+    @pytest.fixture
+    def mock_state(self):
+        """Mocks a state at epoch 1 with Accuracy 0.99."""
+        current_metrics = {'eval': {'Accuracy': 0.99}}
+
+        state = Mock()
+        state.current_metrics = current_metrics
+        state.timer.epoch.value = 1
+
+        return state
+
+    @pytest.mark.timeout(5)
+    def test_eval_start(self, mlperf_callback, mock_state):
+        mlperf_callback.eval_start(mock_state, Mock())
+
+        if not rank_zero():
+            assert mlperf_callback.mllogger.logs == []
+            return
+
+        assert mlperf_callback.mllogger.logs == [{'key': 'eval_start', 'value': None, 'metadata': {'epoch_num': 1}}]
+
+    @pytest.mark.timeout(5)
+    def test_eval_end(self, mlperf_callback, mock_state):
+        mlperf_callback.eval_end(mock_state, Mock())
+
+        if not rank_zero():
+            assert mlperf_callback.success == False
+            assert mlperf_callback.mllogger.logs == []
+            return
+
+        assert mlperf_callback.success == True
+        assert mlperf_callback.mllogger.logs[-1] == {
+            'key': 'run_stop',
+            'value': None,
+            'metadata': {
+                'status': 'success'
+            }
+        }
+
+
+class TestWithMLPerfChecker:
+    """Ensures that the logs created by the MLPerfCallback pass the official package checker."""
+
+    @pytest.mark.timeout(15)
+    def test_mlperf_callback_passes(self, tmpdir, monkeypatch):
+
+        def mock_accuracy(self, state: State):
+            if state.timer.epoch >= 2:
+                return 0.99
+            else:
+                return 0.01
+
+        monkeypatch.setattr(MLPerfCallback, '_get_accuracy', mock_accuracy)
+
+        self.generate_submission(tmpdir)
+
+        if rank_zero():
+            self.run_mlperf_checker(tmpdir, monkeypatch)
+
+    @pytest.mark.timeout(15)
+    def test_mlperf_callback_fails(self, tmpdir, monkeypatch):
+
+        def mock_accuracy(self, state: State):
+            return 0.01
+
+        monkeypatch.setattr(MLPerfCallback, '_get_accuracy', mock_accuracy)
+
+        self.generate_submission(tmpdir)
+        with pytest.raises(ValueError, match='MLPerf checker failed'):
+            self.run_mlperf_checker(tmpdir, monkeypatch)
+
+    def generate_submission(self, directory):
+        """Generates submission files by training the benchark n=5 times."""
+
+        for run in range(5):
+
+            mlperf_callback = MLPerfCallback(
+                root_folder=directory,
+                index=run,
+                cache_clear_cmd=['sleep', '0.1'],
+            )
+
+            trainer = Trainer(
+                model=SimpleModel(),
+                train_dataloader=DataLoader(
+                    dataset=RandomClassificationDataset(),
+                    batch_size=4,
+                    shuffle=False,
+                ),
+                eval_dataloader=DataLoader(
+                    dataset=RandomClassificationDataset(),
+                    shuffle=False,
+                ),
+                max_duration="3ep",
+                deterministic_mode=True,
+                progress_bar=False,
+                log_to_console=False,
+                loggers=[],
+                callbacks=[mlperf_callback],
+                seed=np.random.randint(low=2048),
+            )
+
+            trainer.fit()
+
+    def run_mlperf_checker(self, directory, monkeypatch):
+        """Runs the MLPerf package checker and fails on any errors."""
+
+        # monkeypatch the logging so that logging.error raises Exception
+        def fail_on_error(msg, *args, **kwargs):
+            print(msg.format(*args))
+            raise ValueError('MLPerf checker failed, see logs.')
+
+        monkeypatch.setattr(logging, "error", fail_on_error)
+
+        from mlperf_logging.package_checker.package_checker import check_training_package
+
+        check_training_package(
+            folder=directory,
+            usage="training",
+            ruleset="1.1.0",
+            werror=True,
+            quiet=False,
+            rcp_bypass=False,
+            rcp_bert_train_samples=False,
+            log_output="package_checker.log",
+        )
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 3fb730ca0b..29f64e9f65 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -330,10 +330,8 @@ def test_data_not_augmented(self, config):
 
 @pytest.mark.timeout(15)
 class TestTrainerAssets:
-    """
-    The below is a catch-all test that runs the Trainer
-    with each algorithm, callback, and loggers. Success
-    is defined as a successful training run.
+    """The below is a catch-all test that runs the Trainer with each algorithm, callback, and loggers. Success is
+    defined as a successful training run.
 
     This should eventually be replaced by functional
     tests for each object, in situ of our trainer.
@@ -368,7 +366,10 @@ def config(self, rank_zero_seed: int, request):
 
     @pytest.fixture(params=callback_registry.items(), ids=tuple(callback_registry.keys()))
     def callback(self, request):
-        _, hparams = request.param
+        name, hparams = request.param
+
+        if name == 'mlperf':
+            pytest.skip('mlperf callback tested separately.')
 
         callback = hparams().initialize_object()