fix: lint all files #42

Merged
merged 2 commits on Feb 19, 2024
5 changes: 4 additions & 1 deletion src/modalities/batch.py
@@ -103,12 +103,15 @@ class EvaluationResultBatch(Batch):
losses: Dict[str, torch.Tensor] = field(default_factory=lambda: dict())
metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict())
throughput_metrics: Dict[str, torch.Tensor] = field(default_factory=lambda: dict())

def __str__(self) -> str:
eval_str = (
f"Evaluation result on dataset tag {self.dataloader_tag} after {self.global_train_sample_id + 1} samples:"
)
eval_str += "\n\nlosses: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.losses.items()])
eval_str += "\n\nmetrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.metrics.items()])
eval_str += "\n\nthroughput metrics: " + "\n\t".join([f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()])
eval_str += "\n\nthroughput metrics: " + "\n\t".join(
[f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()]
)
eval_str += "\n==============================================="
return eval_str
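
The changes above are formatting only (an added blank line and a re-wrap of the long throughput-metrics line); behavior is unchanged. For orientation, here is a small, self-contained sketch of the report this __str__ produces, using a stripped-down stand-in class (field names such as dataloader_tag and global_train_sample_id are taken from the hunk; the toy values are invented):

from dataclasses import dataclass, field
from typing import Dict

import torch


@dataclass
class ToyEvalResult:
    """Stand-in for EvaluationResultBatch, reduced to what __str__ needs."""

    dataloader_tag: str
    global_train_sample_id: int
    losses: Dict[str, torch.Tensor] = field(default_factory=dict)
    throughput_metrics: Dict[str, torch.Tensor] = field(default_factory=dict)

    def __str__(self) -> str:
        eval_str = (
            f"Evaluation result on dataset tag {self.dataloader_tag} "
            f"after {self.global_train_sample_id + 1} samples:"
        )
        eval_str += "\n\nlosses: " + "\n\t".join(f"{k}: {v.mean().item()}" for k, v in self.losses.items())
        eval_str += "\n\nthroughput metrics: " + "\n\t".join(
            f"{k}: {v.mean().item()}" for k, v in self.throughput_metrics.items()
        )
        return eval_str


print(ToyEvalResult("val", 999, losses={"CLMCrossEntropyLoss": torch.tensor([2.1, 1.9])}))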
2 changes: 1 addition & 1 deletion src/modalities/dataloader/create_packed_data.py
@@ -116,6 +116,6 @@ def _process_line(self, eos_token_as_bytes: bytes, f: IO, line: str):
raise StopIteration
token_idx += 1
f.write(eos_token_as_bytes)
segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES # segment_length in bytes
segment_length = (token_idx + 1) * self.TOKEN_SIZE_IN_BYTES # segment_length in bytes
self._index_list.append((self._curr_offset, segment_length))
self._curr_offset += segment_length
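
The touched line is a whitespace-only fix around the inline comment; the surrounding bookkeeping is unchanged. As a quick illustration of that bookkeeping, here is a minimal sketch of how the (offset, length) index grows, assuming 4-byte tokens (the real TOKEN_SIZE_IN_BYTES and the document iteration live outside this hunk):

TOKEN_SIZE_IN_BYTES = 4  # assumed token width; the real constant is defined on the packer class

index_list = []   # one (byte offset, segment length in bytes) entry per written document
curr_offset = 0

for num_tokens in [3, 5, 2]:  # hypothetical documents; counts include the EOS token already written
    segment_length = num_tokens * TOKEN_SIZE_IN_BYTES
    index_list.append((curr_offset, segment_length))
    curr_offset += segment_length

print(index_list)  # [(0, 12), (12, 20), (32, 8)]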
2 changes: 1 addition & 1 deletion src/modalities/exceptions.py
@@ -15,4 +15,4 @@ class RunningEnvError(Exception):


class TimeRecorderStateError(Exception):
pass
pass
5 changes: 4 additions & 1 deletion src/modalities/logging_broker/message_broker.py
@@ -1,12 +1,14 @@
from abc import ABC, abstractmethod
from collections import defaultdict
from typing import Dict, List

from modalities.logging_broker.messages import Message, MessageTypes
from modalities.logging_broker.subscriber import MessageSubscriberIF
from typing import Dict, List


class MessageBrokerIF(ABC):
"""Interface for message broker objects."""

@abstractmethod
def add_subscriber(self, subscription: MessageTypes, subscriber: MessageSubscriberIF):
raise NotImplementedError
@@ -18,6 +20,7 @@ def distribute_message(self, message: Message):

class MessageBroker(MessageBrokerIF):
"""The MessageBroker sends notifications to its subscribers."""

def __init__(self) -> None:
self.subscriptions: Dict[MessageTypes, List[MessageSubscriberIF]] = defaultdict(list)

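The visible changes here are import reordering and added blank lines. For context, the broker follows a plain publish/subscribe pattern: subscribers register per message type, and distribute_message (whose body sits below the visible hunk) fans a message out to them. A toy, string-keyed sketch of that flow, not the project's actual implementation:

from collections import defaultdict
from typing import Callable, Dict, List


class ToyBroker:
    """String-keyed stand-in for MessageBroker; the real class keys on MessageTypes."""

    def __init__(self) -> None:
        self.subscriptions: Dict[str, List[Callable[[str], None]]] = defaultdict(list)

    def add_subscriber(self, subscription: str, subscriber: Callable[[str], None]) -> None:
        self.subscriptions[subscription].append(subscriber)

    def distribute_message(self, message_type: str, payload: str) -> None:
        # Assumed dispatch logic: notify every subscriber registered for this type.
        for subscriber in self.subscriptions[message_type]:
            subscriber(payload)


broker = ToyBroker()
broker.add_subscriber("EVALUATION_RESULT", lambda payload: print(f"received: {payload}"))
broker.distribute_message("EVALUATION_RESULT", "loss=1.93")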
3 changes: 2 additions & 1 deletion src/modalities/logging_broker/publisher.py
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from modalities.logging_broker.message_broker import Message, MessageBroker

from modalities.logging_broker.message_broker import Message, MessageBroker
from modalities.logging_broker.messages import MessageTypes

T = TypeVar("T")
@@ -15,6 +15,7 @@ def publish_message(self, payload: T, message_type: MessageTypes):

class MessagePublisher(MessagePublisherIF[T]):
"""The MessagePublisher sends messages through a message broker."""

def __init__(
self,
message_broker: MessageBroker,
2 changes: 1 addition & 1 deletion src/modalities/logging_broker/subscriber.py
@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar

from modalities.logging_broker.messages import Message

T = TypeVar("T")
@@ -11,4 +12,3 @@ class MessageSubscriberIF(ABC, Generic[T]):
@abstractmethod
def consume_message(self, message: Message[T]):
raise NotImplementedError

@@ -2,14 +2,15 @@
from typing import Optional

import rich
import wandb
from rich.console import Group
from rich.panel import Panel

import wandb
from modalities.batch import EvaluationResultBatch
from modalities.config.config import AppConfig, WandbConfig
from modalities.logging_broker.messages import Message
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.config.config import AppConfig, WandbConfig


class DummyResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]):
def consume_message(self, message: Message[EvaluationResultBatch]):
@@ -49,8 +50,15 @@ def consume_message(self, message: Message[EvaluationResultBatch]):
class WandBEvaluationResultSubscriber(MessageSubscriberIF[EvaluationResultBatch]):
"""A subscriber object for the WandBEvaluationResult observable."""

def __init__(self, num_ranks: int, project: str, experiment_id: str, mode: WandbConfig.WandbMode, dir: Path,
experiment_config: Optional[AppConfig] = None) -> None:
def __init__(
self,
num_ranks: int,
project: str,
experiment_id: str,
mode: WandbConfig.WandbMode,
dir: Path,
experiment_config: Optional[AppConfig] = None,
) -> None:
super().__init__()
self.num_ranks = num_ranks

@@ -82,6 +90,4 @@ def consume_message(self, message: Message[EvaluationResultBatch]):
f"{eval_result.dataloader_tag} {metric_key}": metric_values
for metric_key, metric_values in eval_result.throughput_metrics.items()
}
wandb.log(
data=throughput_metrics, step=eval_result.global_train_sample_id + 1
)
wandb.log(data=throughput_metrics, step=eval_result.global_train_sample_id + 1)
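
In this last hunk the diff just collapses the wandb.log call onto one line. The pattern being logged: metric names are prefixed with the dataloader tag and recorded against the global sample id. A self-contained toy version (wandb in disabled mode so it runs offline; all values are invented):

import wandb

run = wandb.init(project="toy-project", mode="disabled")  # disabled: no account or network needed

dataloader_tag = "val"
global_train_sample_id = 999
raw_throughput = {"samples_per_second": 1234.5}  # made-up metric

throughput_metrics = {f"{dataloader_tag} {k}": v for k, v in raw_throughput.items()}
wandb.log(data=throughput_metrics, step=global_train_sample_id + 1)
run.finish()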
17 changes: 10 additions & 7 deletions src/modalities/models/gpt2/preprocess_dataset.py
@@ -1,22 +1,25 @@
import os
from itertools import chain
from datasets import load_dataset
from transformers import GPT2TokenizerFast, GPT2LMHeadModel, GPT2Config

from accelerate import Accelerator
import os
from datasets import load_dataset
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast


def main():

def group_texts(examples):
# Concatenate all texts.
concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
total_length = len(concatenated_examples[list(examples.keys())[0]])
# We drop the small remainder, and if the total_length < block_size
# we exclude this batch and return an empty dict. We could add padding if the
# We drop the small remainder, and if the total_length < block_size
# we exclude this batch and return an empty dict. We could add padding if the
# model supported it instead of this drop, you can customize this part to your needs.
total_length = (total_length // block_size) * block_size
# Split by chunks of max_len.
result = {k: [t[i: i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items()}
result = {
k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
for k, t in concatenated_examples.items()
}
result["labels"] = result["input_ids"].copy()
return result

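The edits in this hunk reorder imports and re-wrap the comment and the chunking comprehension. To see what group_texts does, here is a quick standalone check on toy token ids (block_size is an assumed local value for illustration; in the script it is defined elsewhere):

from itertools import chain

block_size = 4  # assumed value for illustration


def group_texts(examples):
    # Concatenate all texts, drop the remainder, and split into block_size chunks.
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    total_length = (total_length // block_size) * block_size
    result = {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result


batch = {"input_ids": [[1, 2, 3], [4, 5, 6, 7, 8, 9]]}
print(group_texts(batch))
# {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]], 'labels': [[1, 2, 3, 4], [5, 6, 7, 8]]}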
4 changes: 3 additions & 1 deletion src/modalities/models/model.py
@@ -1,9 +1,11 @@
from abc import abstractmethod
from typing import Dict
from modalities.batch import DatasetBatch, InferenceResultBatch

import torch
import torch.nn as nn

from modalities.batch import DatasetBatch, InferenceResultBatch


class NNModel(nn.Module):
def __init__(self, seed: int = None):
3 changes: 1 addition & 2 deletions src/modalities/test.py
@@ -3,7 +3,6 @@
from rich.progress import Progress

with Progress() as progress:

task1 = progress.add_task("[red]Downloading...", total=1000)
task2 = progress.add_task("[green]Processing...", total=1000)
task3 = progress.add_task("[cyan]Cooking...", total=1000)
@@ -12,4 +11,4 @@
progress.update(task1, advance=0.5)
progress.update(task2, advance=0.3)
progress.update(task3, advance=0.9)
time.sleep(0.02)
time.sleep(0.02)