From f53c26b71bb5ee5f7d7b5d181c9914b74de5ea09 Mon Sep 17 00:00:00 2001 From: Alex Ge Date: Wed, 20 Apr 2022 16:57:56 +0100 Subject: [PATCH] init add of wandb ai writer --- ml-agents/mlagents/plugins/stats_writer.py | 4 +++- ml-agents/mlagents/trainers/stats.py | 24 +++++++++++++++++++ .../mlagents/trainers/tests/test_stats.py | 18 ++++++++++++++ 3 files changed, 45 insertions(+), 1 deletion(-) diff --git a/ml-agents/mlagents/plugins/stats_writer.py b/ml-agents/mlagents/plugins/stats_writer.py index 17acefd32e..59bc4d4616 100644 --- a/ml-agents/mlagents/plugins/stats_writer.py +++ b/ml-agents/mlagents/plugins/stats_writer.py @@ -13,7 +13,7 @@ from mlagents_envs import logging_util from mlagents.plugins import ML_AGENTS_STATS_WRITER from mlagents.trainers.settings import RunOptions -from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter +from mlagents.trainers.stats import TensorboardWriter, GaugeWriter, ConsoleWriter, WandbWriter logger = logging_util.get_logger(__name__) @@ -25,6 +25,7 @@ def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]: * A TensorboardWriter to write information to TensorBoard * A GaugeWriter to record our internal stats * A ConsoleWriter to output to stdout. + * A Wandb.AI Writer """ checkpoint_settings = run_options.checkpoint_settings return [ @@ -35,6 +36,7 @@ def get_default_stats_writers(run_options: RunOptions) -> List[StatsWriter]: ), GaugeWriter(), ConsoleWriter(), + WandbWriter() ] diff --git a/ml-agents/mlagents/trainers/stats.py b/ml-agents/mlagents/trainers/stats.py index 78dc33893e..bdde74a3ae 100644 --- a/ml-agents/mlagents/trainers/stats.py +++ b/ml-agents/mlagents/trainers/stats.py @@ -6,6 +6,7 @@ import os import time from threading import RLock +import wandb from mlagents_envs.side_channel.stats_side_channel import StatsAggregationMethod @@ -286,6 +287,29 @@ def add_property( self.summary_writers[category].flush() +class WandbWriter(StatsWriter): + def __init__( + self, + config: dict + ): + """ + A Weights and Biases Wrapper that will add stats to your wandb.ai board. + """ + wandb.init(reinit=True, + config=config) + + def write_stats( + self, + category : str, + values : dict, + step : int + ) -> None: + """ + Write some stats for a given category and step + """ + wandb.log({category : values}, step=step) + + class StatsReporter: writers: List[StatsWriter] = [] stats_dict: Dict[str, Dict[str, List]] = defaultdict(lambda: defaultdict(list)) diff --git a/ml-agents/mlagents/trainers/tests/test_stats.py b/ml-agents/mlagents/trainers/tests/test_stats.py index e89b934869..1eaff42e15 100644 --- a/ml-agents/mlagents/trainers/tests/test_stats.py +++ b/ml-agents/mlagents/trainers/tests/test_stats.py @@ -12,6 +12,7 @@ StatsSummary, GaugeWriter, ConsoleWriter, + WandbWriter, StatsPropertyType, StatsAggregationMethod, ) @@ -248,3 +249,20 @@ def test_selfplay_console_writer(self): self.assertIn( "Mean Reward: 1.000. Std of Reward: 0.000. Training.", cm.output[0] ) + + +class WandbWriterTest(unittest.TestCase): + def test_wandb_full(self): + category = "GeneralStuff" + config = {"caller" : "ml-agents"} + wandb_writer = WandbWriter(config=config) + wandb_writer.write_stats( + category = category, + values = { + "Environment/Cumulative Reward": -15, + "Is Training": True, + "Self-play/ELO": 1.0, + }, + step = 10, + ) +