Add data streaming support through mosaic-streaming #1525

Closed
14 changes: 14 additions & 0 deletions README.md
@@ -475,6 +475,20 @@ tokens: # these are delimiters

When you include these tokens in your axolotl config, axolotl adds these tokens to the tokenizer's vocabulary.

#### Streaming dataset

Use [mosaicml-streaming](https://github.com/mosaicml/streaming?tab=readme-ov-file#quick-start) to prepare your dataset for streaming. This lets you train on datasets too large to fit on local disk ("infinite" datasets). Add `streaming: true` to your `datasets` entry:

```yaml
datasets:
- ds_type: json
path: s3://my-bucket/datasets-path/
type: completion
streaming: true
```

Ensure that the dataset has been uploaded in [mosaicml-streaming](https://github.com/mosaicml/streaming?tab=readme-ov-file#quick-start)'s MDS format beforehand.
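
As a rough illustration, a dataset matching the config above could be written in this format with `streaming.MDSWriter`. The bucket, input file, and single `text` column below are hypothetical placeholders, not part of this change:

```python
# Hypothetical sketch: convert a local JSONL file with a "text" field into
# MDS shards and upload them to S3 as they are written.
import json

from streaming import MDSWriter

# The column name and type must match what training will read later
# (`completion`-style datasets consume a "text" field).
columns = {"text": "str"}

with MDSWriter(out="s3://my-bucket/datasets-path/", columns=columns) as writer:
    with open("train.jsonl", encoding="utf-8") as fin:
        for line in fin:
            sample = json.loads(line)
            writer.write({"text": sample["text"]})
```

When `out` is a remote URI, `MDSWriter` uploads each shard as it is completed, so a separate upload step should not be needed.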

### Inference Playground

Axolotl allows you to load your model in an interactive terminal playground for quick experimentation.
1 change: 1 addition & 0 deletions docs/config.qmd
@@ -90,6 +90,7 @@ datasets:
shards: # Optional[int] number of shards to split data into
name: # Optional[str] name of dataset configuration to load
train_on_split: train # Optional[str] name of dataset split to load from
  streaming: null # Optional[bool] whether to load the dataset with `mosaicml-streaming`

# Optional[str] fastchat conversation type, only used with type: sharegpt
conversation: # Options (see Conversation 'name'): https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
1 change: 1 addition & 0 deletions requirements.txt
@@ -31,6 +31,7 @@ art
fschat==0.2.36
gradio==3.50.2
tensorboard
mosaicml-streaming==0.7.5

mamba-ssm==1.2.0.post1

1 change: 1 addition & 0 deletions src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -96,6 +96,7 @@ class SFTDataset(BaseModel):
data_files: Optional[Union[str, List[str]]] = None
name: Optional[str] = None
ds_type: Optional[str] = None
streaming: Optional[bool] = None
train_on_split: Optional[str] = None

field: Optional[str] = None
55 changes: 51 additions & 4 deletions src/axolotl/utils/data/sft.py
@@ -51,6 +51,50 @@
LOG = logging.getLogger("axolotl")


def load_streaming_dataset(config_dataset):
"""
    Load a streaming dataset from remote storage.

    This function initializes a streaming dataset from a remote S3 bucket,
    wraps the data in a generator, and converts it into a Hugging Face
    dataset so that downstream code can consume it like any other dataset.

    Parameters:
    - config_dataset: Dataset configuration object; its `path` attribute
      must point to the remote storage location.

Returns:
- ds (datasets.Dataset): A Hugging Face dataset object that streams data from the specified remote location.
"""
    # These imports are local because `mosaicml-streaming` is an optional dependency.
from functools import partial

from datasets import Features, Value
from streaming import StreamingDataset

# Initialize the `StreamingDataset` with configurations.
streaming_dataset = StreamingDataset(
local=None, remote=config_dataset.path, shuffle=True, batch_size=4
)

# Define dataset features according to the axolotl structure.
features = Features({"text": Value("string")})

# Shim between `StreamingDataset` and `Dataset`.
def generator_from_streaming_dataset(streaming_dataset):
yield from streaming_dataset

# Create a Hugging Face dataset from the generator.
#
# This is necessary because downstream functions use a different interface
# than `StreamingDataset` (e.g. the `features` attribute).
ds = Dataset.from_generator( # pylint: disable=invalid-name
generator=partial(generator_from_streaming_dataset, streaming_dataset),
features=features,
)

return ds


def prepare_dataset(cfg, tokenizer):
prompters = []
if not cfg.pretraining_dataset:
@@ -317,10 +361,13 @@ def for_d_in_datasets(dataset_configs):
)
elif ds_from_cloud and remote_file_system:
if remote_file_system.isdir(config_dataset.path):
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
if config_dataset.streaming:
ds = load_streaming_dataset(config_dataset)
else:
ds = load_from_disk(
config_dataset.path,
storage_options=storage_options,
)
elif remote_file_system.isfile(config_dataset.path):
ds_type = get_ds_type(config_dataset)
ds = load_dataset(