From 55a896f5ad16f673acc60cafcd2e36ad238b0114 Mon Sep 17 00:00:00 2001 From: Max Luebbering Date: Thu, 11 Jul 2024 17:29:13 +0200 Subject: [PATCH] fix: fixed consume_dict --- .../subscriber_impl/results_subscriber.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index f8df3de8..09547b1f 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -1,11 +1,12 @@ from pathlib import Path +from typing import Any, Dict import rich -import wandb import yaml from rich.console import Group from rich.panel import Panel +import wandb from modalities.batch import EvaluationResultBatch from modalities.config.config import WandbMode from modalities.logging_broker.messages import Message @@ -17,6 +18,9 @@ def consume_message(self, message: Message[EvaluationResultBatch]): """Consumes a message from a message broker.""" pass + def consume_dict(self, mesasge_dict: Dict[str, Any]): + pass + class RichResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): def __init__(self, num_ranks: int) -> None: @@ -46,6 +50,9 @@ def consume_message(self, message: Message[EvaluationResultBatch]): if losses or metrics: rich.print(Panel(Group(*group_content))) + def consume_dict(self, mesasge_dict: Dict[str, Any]): + raise NotImplementedError + class WandBEvaluationResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): """A subscriber object for the WandBEvaluationResult observable.""" @@ -68,8 +75,9 @@ def __init__( self.run.log_artifact(config_file_path, name=f"config_{wandb.run.id}", type="config") - def consume_key_value(self, key: str, value: str): - self.run.config[key] = value + def consume_dict(self, mesasge_dict: Dict[str, Any]): + for k, v in mesasge_dict.items(): + self.run.config[k] = v def consume_message(self, message: Message[EvaluationResultBatch]): """Consumes a message from a message broker."""