Skip to content

Commit

Permalink
fix: fixed consume_dict
Browse files Browse the repository at this point in the history
  • Loading branch information
le1nux committed Jul 11, 2024
1 parent 06afc42 commit 55a896f
Showing 1 changed file with 11 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from pathlib import Path
from typing import Any, Dict

import rich
import wandb
import yaml
from rich.console import Group
from rich.panel import Panel

import wandb
from modalities.batch import EvaluationResultBatch
from modalities.config.config import WandbMode
from modalities.logging_broker.messages import Message
Expand All @@ -17,6 +18,9 @@ def consume_message(self, message: Message[EvaluationResultBatch]):
"""Consumes a message from a message broker."""
pass

def consume_dict(self, mesasge_dict: Dict[str, Any]):
pass


class RichResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]):
def __init__(self, num_ranks: int) -> None:
Expand Down Expand Up @@ -46,6 +50,9 @@ def consume_message(self, message: Message[EvaluationResultBatch]):
if losses or metrics:
rich.print(Panel(Group(*group_content)))

def consume_dict(self, mesasge_dict: Dict[str, Any]):
raise NotImplementedError


class WandBEvaluationResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]):
"""A subscriber object for the WandBEvaluationResult observable."""
Expand All @@ -68,8 +75,9 @@ def __init__(

self.run.log_artifact(config_file_path, name=f"config_{wandb.run.id}", type="config")

def consume_key_value(self, key: str, value: str):
self.run.config[key] = value
def consume_dict(self, mesasge_dict: Dict[str, Any]):
for k, v in mesasge_dict.items():
self.run.config[k] = v

def consume_message(self, message: Message[EvaluationResultBatch]):
"""Consumes a message from a message broker."""
Expand Down

0 comments on commit 55a896f

Please sign in to comment.