Add data streaming support through mosaic-streaming #1525

Closed
1 change: 1 addition & 0 deletions requirements.txt
@@ -31,6 +31,7 @@ art
fschat==0.2.36
gradio==3.50.2
tensorboard
mosaicml-streaming
Author:

question: mosaicml-streaming should be an optional dependency. Is this the right way of adding it in this capacity?
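
For context, the usual complement to declaring an optional extra in setup.py (as this PR does below) is a guarded import that points users at the extra. A minimal sketch, not taken from this PR; the helper name and error message are illustrative:

    # Illustrative only: fail with an actionable message when the optional
    # dependency is missing.
    def require_mosaicml_streaming():
        try:
            from streaming import StreamingDataset  # provided by mosaicml-streaming
        except ImportError as exc:
            raise ImportError(
                "Streaming datasets require mosaicml-streaming; "
                "install it with `pip install -e '.[mosaicml-streaming]'`"
            ) from exc
        return StreamingDataset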

Collaborator:

I'm okay with it being a required dependency. If it causes issues down the line we can make it optional then. Hoping to keep things simpler.

Collaborator:

Could this version be locked?

Author:

Addressed by 976bc13.


mamba-ssm==1.2.0.post1

3 changes: 3 additions & 0 deletions setup.py
@@ -92,5 +92,8 @@ def parse_requirements():
"galore": [
"galore_torch",
],
"mosaicml-streaming": [
"mosaicml-streaming",
],
    },
)
1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -96,6 +96,7 @@ class SFTDataset(BaseModel):
    data_files: Optional[Union[str, List[str]]] = None
    name: Optional[str] = None
    ds_type: Optional[str] = None
    streaming: Optional[bool] = None
    train_on_split: Optional[str] = None

    field: Optional[str] = None
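
For illustration (not part of the diff), the new flag validates like any other optional field on a dataset entry; a hypothetical example, assuming SFTDataset's existing path field (it is read as config_dataset.path in sft.py below):

    # Hypothetical example: enabling streaming on a dataset config entry.
    entry = SFTDataset(path="s3://my-bucket/my-mds-dataset", streaming=True)
    assert entry.streaming is True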
54 changes: 50 additions & 4 deletions src/axolotl/utils/data/sft.py
@@ -51,6 +51,49 @@
LOG = logging.getLogger("axolotl")


def load_streaming_dataset(config_dataset):
    """
    Load a streaming dataset from remote storage.

    This function initializes a streaming dataset from a remote S3 bucket,
    wraps the data in a generator, and then converts it into a Hugging Face
    dataset for compatibility with the rest of the pipeline.

    Parameters:
    - config_dataset: Dataset configuration providing at least the remote `path` to stream from.

    Returns:
    - ds (datasets.Dataset): A Hugging Face dataset built from the streamed samples.
    """
    # These imports are local because `mosaicml-streaming` is an optional dependency.
    from functools import partial

    from datasets import Dataset, Features, Value
    from streaming import StreamingDataset

    # Initialize the `StreamingDataset` from the remote path; `shuffle` and
    # `batch_size` are currently hardcoded.
    streaming_dataset = StreamingDataset(
        local=None, remote=config_dataset.path, shuffle=True, batch_size=4
    )

    # Define dataset features according to the axolotl structure.
    features = Features({"text": Value("string")})

    # Shim between `StreamingDataset` and `Dataset`.
    def generator_from_streaming_dataset(streaming_dataset):
        yield from streaming_dataset

    # Create a Hugging Face dataset from the generator.
    #
    # This is necessary because downstream functions use a different interface
    # than `StreamingDataset` (e.g. the `features` attribute).
    ds = Dataset.from_generator(
Collaborator @winglian, Apr 16, 2024:

This becomes an IterableDataset, right?

Author:

@winglian ,

Sorry for the delay here.

No, that was something I wanted to verify, but it looks like it goes to def process and everything is evaluated eagerly.

I started a draft like:

    def process(self, dataset):
        features = dataset.features.keys()
        map_kwargs = {}
        if self.prompt_tokenizer.supports_batched:
            map_kwargs["batched"] = True
            map_kwargs["batch_size"] = 100

        if isinstance(dataset, IterableDataset):
            # IterableDataset.map does not accept num_proc, desc, or
            # keep_in_memory, so only the supported kwargs are passed here.
            return dataset.map(
                self.prompt_tokenizer.tokenize_prompt,
                remove_columns=features,
                **map_kwargs,
            )
        else:
            num_proc = min(
                64, self.process_count if self.process_count else os.cpu_count()
            )

            return dataset.map(
                self.prompt_tokenizer.tokenize_prompt,
                num_proc=num_proc,
                remove_columns=features,
                keep_in_memory=self.keep_in_memory,
                desc="Tokenizing Prompts",
                **map_kwargs,
            )

But I don't know whether that's a good idea. The .map API is different between Dataset (here) and IterableDataset (here).

Feel free to remove the "ready to merge" tag from this.
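
To make the .map mismatch concrete, here is a small sketch of the difference (not code from this PR; tokenize_fn stands in for self.prompt_tokenizer.tokenize_prompt): Dataset.map runs eagerly and accepts num_proc, desc, and keep_in_memory, while IterableDataset.map is lazy and accepts none of them, so shared kwargs have to be filtered per type.

    from datasets import Dataset, IterableDataset

    def tokenize_any(dataset, tokenize_fn):
        # Kwargs accepted by both Dataset.map and IterableDataset.map.
        common = {"batched": True, "batch_size": 100}
        if isinstance(dataset, IterableDataset):
            # Lazy; num_proc, desc, and keep_in_memory are not accepted here.
            return dataset.map(tokenize_fn, **common)
        # Eager; supports multiprocessing and a progress description.
        return dataset.map(tokenize_fn, num_proc=4, desc="Tokenizing Prompts", **common)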

        generator=partial(generator_from_streaming_dataset, streaming_dataset),
        features=features,
    )

    return ds
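
A minimal usage sketch (not part of the diff): the function only reads path from the config, so any object exposing that attribute works, provided mosaicml-streaming is installed and the S3 prefix (hypothetical below) holds an MDS-format dataset with a text column.

    from types import SimpleNamespace

    # Hypothetical S3 prefix pointing at an MDS-format dataset with a "text" column.
    config_dataset = SimpleNamespace(path="s3://my-bucket/my-mds-dataset")
    ds = load_streaming_dataset(config_dataset)
    print(ds.features)  # {'text': Value(dtype='string', id=None)}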


def prepare_dataset(cfg, tokenizer):
    prompters = []
    if not cfg.pretraining_dataset:
@@ -317,10 +360,13 @@ def for_d_in_datasets(dataset_configs):
            )
        elif ds_from_cloud and remote_file_system:
            if remote_file_system.isdir(config_dataset.path):
                if config_dataset.streaming:
                    ds = load_streaming_dataset(config_dataset)
                else:
                    ds = load_from_disk(
                        config_dataset.path,
                        storage_options=storage_options,
                    )
            elif remote_file_system.isfile(config_dataset.path):
                ds_type = get_ds_type(config_dataset)
                ds = load_dataset(