From 5b60c2f771bd0563644b054d171999540fa9ffc0 Mon Sep 17 00:00:00 2001 From: Felix Stollenwerk Date: Tue, 30 Jan 2024 22:48:35 +0100 Subject: [PATCH] fix: lint all files --- src/modalities/batch.py | 5 ++++- src/modalities/config/config.py | 1 + .../dataloader/create_packed_data.py | 2 +- src/modalities/exceptions.py | 2 +- .../logging_broker/message_broker.py | 5 ++++- src/modalities/logging_broker/publisher.py | 3 ++- src/modalities/logging_broker/subscriber.py | 2 +- .../subscriber_impl/results_subscriber.py | 20 ++++++++++++------- .../models/gpt2/preprocess_dataset.py | 17 +++++++++------- src/modalities/models/model.py | 4 +++- src/modalities/test.py | 3 +-- 11 files changed, 41 insertions(+), 23 deletions(-) diff --git a/src/modalities/batch.py b/src/modalities/batch.py index bc6c62c0..7cf3f34e 100644 --- a/src/modalities/batch.py +++ b/src/modalities/batch.py @@ -103,12 +103,15 @@ class EvaluationResultBatch(Batch): losses: Dict[str, torch.Tensor] = field(default_factory=lambda: dict()) metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict()) throughput_metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict()) + def __str__(self) -> str: eval_str = ( f"Evaluation result on dataset tag {self.dataloader_tag} after {self.global_train_sample_id + 1} samples:" ) eval_str += "\n\nlosses: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.losses.items()]) eval_str += "\n\nmetrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.metrics.items()]) - eval_str += "\n\nthroughput metrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()]) + eval_str += "\n\nthroughput metrics: " + "\n\t".join( + [f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()] + ) eval_str += "\n===============================================" return eval_str diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 0e166242..bfb353a2 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -284,6 +284,7 @@ class RunMode(Enum): FROM_SCRATCH = "FROM_SCRATCH" WARM_START = "WARM_START" + class ModalitiesSetupConfig(BaseModel): class WarmStartSettings(BaseModel): checkpoint_model_path: Path diff --git a/src/modalities/dataloader/create_packed_data.py b/src/modalities/dataloader/create_packed_data.py index 6e8d4d3c..826baf44 100644 --- a/src/modalities/dataloader/create_packed_data.py +++ b/src/modalities/dataloader/create_packed_data.py @@ -116,6 +116,6 @@ def _process_line(self, eos_token_as_bytes: bytes, f: IO, line: str): raise StopIteration token_idx += 1 f.write(eos_token_as_bytes) - segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES # segment_length in bytes + segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES # segment_length in bytes self._index_list.append((self._curr_offset, segment_length)) self._curr_offset += segment_length diff --git a/src/modalities/exceptions.py b/src/modalities/exceptions.py index c5e5e3a2..07e344d5 100644 --- a/src/modalities/exceptions.py +++ b/src/modalities/exceptions.py @@ -15,4 +15,4 @@ class RunningEnvError(Exception): class TimeRecorderStateError(Exception): - pass \ No newline at end of file + pass diff --git a/src/modalities/logging_broker/message_broker.py b/src/modalities/logging_broker/message_broker.py index d5f4aec2..7b38e58f 100644 --- a/src/modalities/logging_broker/message_broker.py +++ b/src/modalities/logging_broker/message_broker.py @@ -1,12 +1,14 @@ from abc import ABC, abstractmethod from collections import defaultdict +from typing import Dict, List + from modalities.logging_broker.messages import Message, MessageTypes from modalities.logging_broker.subscriber import MessageSubscriberIF -from typing import Dict, List class MessageBrokerIF(ABC): """Interface for message broker objects.""" + @abstractmethod def add_subscriber(self, subscription: MessageTypes, subscriber: MessageSubscriberIF): raise NotImplementedError @@ -18,6 +20,7 @@ def distribute_message(self, message: Message): class MessageBroker(MessageBrokerIF): """The MessageBroker sends notifications to its subscribers.""" + def __init__(self) -> None: self.subscriptions: Dict[MessageTypes, List[MessageSubscriberIF]] = defaultdict(list) diff --git a/src/modalities/logging_broker/publisher.py b/src/modalities/logging_broker/publisher.py index 34ff834b..28cc27de 100644 --- a/src/modalities/logging_broker/publisher.py +++ b/src/modalities/logging_broker/publisher.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar -from modalities.logging_broker.message_broker import Message, MessageBroker +from modalities.logging_broker.message_broker import Message, MessageBroker from modalities.logging_broker.messages import MessageTypes T = TypeVar("T") @@ -15,6 +15,7 @@ def publish_message(self, payload: T, message_type: MessageTypes): class MessagePublisher(MessagePublisherIF[T]): """The MessagePublisher sends messages through a message broker.""" + def __init__( self, message_broker: MessageBroker, diff --git a/src/modalities/logging_broker/subscriber.py b/src/modalities/logging_broker/subscriber.py index 7e965b75..6b4e5c2d 100644 --- a/src/modalities/logging_broker/subscriber.py +++ b/src/modalities/logging_broker/subscriber.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar + from modalities.logging_broker.messages import Message T = TypeVar("T") @@ -11,4 +12,3 @@ class MessageSubscriberIF(ABC, Generic[T]): @abstractmethod def consume_message(self, message: Message[T]): raise NotImplementedError - diff --git a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py index b2965725..92fe0fc1 100644 --- a/src/modalities/logging_broker/subscriber_impl/results_subscriber.py +++ b/src/modalities/logging_broker/subscriber_impl/results_subscriber.py @@ -2,14 +2,15 @@ from typing import Optional import rich +import wandb from rich.console import Group from rich.panel import Panel -import wandb from modalities.batch import EvaluationResultBatch +from modalities.config.config import AppConfig, WandbConfig from modalities.logging_broker.messages import Message from modalities.logging_broker.subscriber import MessageSubscriberIF -from modalities.config.config import AppConfig, WandbConfig + class DummyResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): def consume_message(self, message: Message[EvaluationResultBatch]): @@ -49,8 +50,15 @@ def consume_message(self, message: Message[EvaluationResultBatch]): class WandBEvaluationResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]): """A subscriber object for the WandBEvaluationResult observable.""" - def __init__(self, num_ranks: int, project: str, experiment_id: str, mode: WandbConfig.WandbMode, dir: Path, - experiment_config: Optional[AppConfig] = None) -> None: + def __init__( + self, + num_ranks: int, + project: str, + experiment_id: str, + mode: WandbConfig.WandbMode, + dir: Path, + experiment_config: Optional[AppConfig] = None, + ) -> None: super().__init__() self.num_ranks = num_ranks @@ -82,6 +90,4 @@ def consume_message(self, message: Message[EvaluationResultBatch]): f"{eval_result.dataloader_tag} {metric_key}": metric_values for metric_key, metric_values in eval_result.throughput_metrics.items() } - wandb.log( - data=throughput_metrics, step=eval_result.global_train_sample_id + 1 - ) + wandb.log(data=throughput_metrics, step=eval_result.global_train_sample_id + 1) diff --git a/src/modalities/models/gpt2/preprocess_dataset.py b/src/modalities/models/gpt2/preprocess_dataset.py index 99afb069..9ee372a1 100644 --- a/src/modalities/models/gpt2/preprocess_dataset.py +++ b/src/modalities/models/gpt2/preprocess_dataset.py @@ -1,21 +1,24 @@ +import os from itertools import chain -from datasets import load_dataset -from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config + from accelerate import Accelerator -import os +from datasets import load_dataset +from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast def main(): - def group_texts(examples): # Concatenate all texts. concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()} total_length = len(concatenated_examples[list(examples.keys())[0]]) - # We drop the small remainder, and if the total_length < block_size we exclude this batch and return an empty dict. - # We could add padding if the model supported it instead of this drop, you can customize this part to your needs. + # We drop the small remainder, and if total_length < block_size we exclude this batch and return an empty dict + # We could add padding if the model supported it instead of this drop, you can customize this part to your needs total_length = (total_length // block_size) * block_size # Split by chunks of max_len. - result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()} + result = { + k: [t[i : i + block_size] for i in range(0, total_length, block_size)] + for k, t in concatenated_examples.items() + } result["labels"] = result["input_ids"].copy() return result diff --git a/src/modalities/models/model.py b/src/modalities/models/model.py index d00a8043..511419b9 100644 --- a/src/modalities/models/model.py +++ b/src/modalities/models/model.py @@ -1,9 +1,11 @@ from abc import abstractmethod from typing import Dict -from modalities.batch import DatasetBatch, InferenceResultBatch + import torch import torch.nn as nn +from modalities.batch import DatasetBatch, InferenceResultBatch + class NNModel(nn.Module): def __init__(self, seed: int = None): diff --git a/src/modalities/test.py b/src/modalities/test.py index ea16a091..f81c3630 100644 --- a/src/modalities/test.py +++ b/src/modalities/test.py @@ -3,7 +3,6 @@ from rich.progress import Progress with Progress() as progress: - task1 = progress.add_task("[red]Downloading...", total=1000) task2 = progress.add_task("[green]Processing...", total=1000) task3 = progress.add_task("[cyan]Cooking...", total=1000) @@ -12,4 +11,4 @@ progress.update(task1, advance=0.5) progress.update(task2, advance=0.3) progress.update(task3, advance=0.9) - time.sleep(0.02) \ No newline at end of file + time.sleep(0.02)