Batch size beginning to vary half way through epoch #179

Open
MarcoForte opened this issue Jun 21, 2024 · 6 comments
Labels: bug (Something isn't working), help wanted (Extra attention is needed)

Comments

@MarcoForte

MarcoForte commented Jun 21, 2024

🐛 Bug

Hello, I'm running into an issue where my batch size begins to vary halfway through an epoch.

To Reproduce

I logged every batch whose size deviated from 64. It happens in every epoch, and also when training on a single GPU.

[Screenshot 2024-06-21 at 11:32:28 (image not preserved)]
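The logging described above can be sketched as follows, assuming the batch sizes for one epoch have been collected into a list (for example by recording len(batch) at each step). find_deviations is a hypothetical helper, not part of litdata or Lightning. Reporting each deviation's position as a fraction of the epoch makes a "halfway through" pattern easy to spot:

```python
from typing import List, Tuple


def find_deviations(batch_sizes: List[int], expected: int = 64) -> List[Tuple[int, int, float]]:
    """Return (step, size, epoch_fraction) for each batch whose size differs from `expected`.

    epoch_fraction is step / number_of_steps, so 0.5 means the batch arrived
    halfway through the epoch. (Hypothetical helper for illustration.)
    """
    total = len(batch_sizes)
    return [
        (step, size, step / total)
        for step, size in enumerate(batch_sizes)
        if size != expected
    ]


# Illustrative, made-up sizes (not the reporter's data): two short batches in the second half.
sizes = [64, 64, 64, 64, 64, 37, 64, 51, 64, 64]
for step, size, fraction in find_deviations(sizes):
    print(f"step {step}: batch size {size} ({fraction:.0%} through the epoch)")
```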

Code sample

Unfortunately I can't share the code, but I will share as much as I can, and I can run many experiments.
I'm launching the training with torchrun --standalone --nnodes=1 --nproc-per-node=8 main.py
I use sets = [StreamingDataset(a), StreamingDataset(b)] and DataLoader(CombinedStreamingDataset(datasets=sets)).
I launch the training through trainer.fit, with drop_last=True.
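For reference, with drop_last=True a simple single-process loader should only ever emit full batches: the trailing num_samples % batch_size samples are discarded. The sketch below models that rule in plain Python (expected_batch_sizes is a hypothetical helper). It deliberately ignores multi-worker loading and distributed sharding, where the data is typically split across workers and ranks before batching, so it describes the expected behavior, not litdata's actual implementation:

```python
from typing import List


def expected_batch_sizes(num_samples: int, batch_size: int, drop_last: bool) -> List[int]:
    """Batch sizes a simple single-process loader would yield for one epoch (hypothetical model)."""
    full_batches, remainder = divmod(num_samples, batch_size)
    sizes = [batch_size] * full_batches
    if remainder and not drop_last:
        # Without drop_last, the leftover samples form one short final batch.
        sizes.append(remainder)
    return sizes


# 1000 samples with batch size 64: 15 full batches and 40 samples left over.
print(expected_batch_sizes(1000, 64, drop_last=True)[-1])    # last batch is still 64
print(len(expected_batch_sizes(1000, 64, drop_last=False)))  # 16 batches, the last one of size 40
```

Under this model a short batch can only ever be the final batch of an epoch, never one in the middle.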

Expected behavior

Fixed batch size throughout epoch.

Environment

Using the NGC 23.05 container:

Ubuntu 22.04 including Python 3.10
NVIDIA CUDA 12.4.1
NVIDIA cuBLAS 12.4.5.8
NVIDIA cuDNN 9.1.0.70
NVIDIA NCCL 2.21.5
lightning==2.3.0
litdata==0.2.12
8 x H100

@MarcoForte added the bug (Something isn't working) and help wanted (Extra attention is needed) labels on Jun 21, 2024.

Hi! Thanks for your contribution, great first issue!

@tchaton
Collaborator

tchaton commented Jun 21, 2024

Hey @MarcoForte. Fascinating, I have never seen this ;) Can you share a reproducible script with fake data? Does this issue still happen if you use a single StreamingDataset?

@MarcoForte
Author

Cheers @tchaton, yeah it was a bit surprising 👀. I only noticed it because I was using torch.compile, and the changing batch size was triggering recompilation, which caused a big slowdown. Otherwise it might well go unnoticed...
It also happened with a single StreamingDataset, bypassing the CombinedStreamingDataset.

If I find a moment I'll try to put together a reproducible script, thanks.
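Until the root cause is found, one possible stopgap when training under torch.compile is to skip any batch whose leading dimension differs from the expected size, so the compiled graph never sees a new input shape. This is a hypothetical workaround sketched here, not something proposed in the thread; it silently discards samples, so it only masks the problem. The generator works on anything that supports len(), including tensors, whose len() is the size of their first dimension:

```python
import logging
from typing import Iterable, Iterator, Sized, TypeVar

B = TypeVar("B", bound=Sized)
logger = logging.getLogger(__name__)


def fixed_size_batches(batches: Iterable[B], expected: int) -> Iterator[B]:
    """Yield only batches of exactly `expected` items, logging and skipping the rest (hypothetical helper)."""
    for step, batch in enumerate(batches):
        if len(batch) != expected:
            logger.warning("skipping batch %d with size %d (expected %d)", step, len(batch), expected)
            continue
        yield batch


# Usage with plain lists standing in for tensors:
batches = [[0] * 64, [0] * 37, [0] * 64]
print([len(b) for b in fixed_size_batches(batches, expected=64)])
```

An alternative that keeps every sample is to pass dynamic=True to torch.compile, which asks the compiler to treat input shapes as dynamic instead of specializing on each new batch size.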

@tchaton
Collaborator

tchaton commented Jun 22, 2024

Thanks a lot @MarcoForte. Looking forward to the code so I can debug it.

@tchaton
Collaborator

tchaton commented Jun 25, 2024

Hey @MarcoForte, any chance you could provide a reproducible script?

@tchaton
Collaborator

tchaton commented Jun 27, 2024

Hey @MarcoForte

Unfortunately, I can't reproduce this issue on my end.

```python
import os
from lightning_cloud.utils.data_connection import add_s3_connection
from lightning.data import StreamingDataset, StreamingDataLoader
from lightning.data.streaming.serializers import JPEGSerializer
import torchvision.transforms.v2 as T
import open_clip
from tqdm import tqdm

# 1. Add the prepared dataset to your teamspace
add_s3_connection("laoin-400m")

# 2. Create the streaming dataset
class LAIONStreamingDataset(StreamingDataset):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tokenizer = open_clip.get_tokenizer('ViT-B-32', context_length=512) # You can use any tokenizer
        self.serializer = JPEGSerializer()
        self.preprocess = T.Normalize(mean=(0.48145466, 0.4578275, 0.40821073), std=(0.26862954, 0.26130258, 0.27577711))

    def __getitem__(self, index):
        _, image, text, _, _, _ = super().__getitem__(index)
        image = self.serializer.deserialize(image).float()
        return self.preprocess(image)

dataset = LAIONStreamingDataset(input_dir="/teamspace/s3_connections/laoin-400m")
dataloader = StreamingDataLoader(dataset, batch_size=64, num_workers=os.cpu_count())

batch_size = 64

for batch in tqdm(dataloader):
    assert batch.shape[0] == batch_size
```
