Merge pull request #48 from Modalities/feat/merge-pbin-files
feat: merge utility for pbin files
luzian-hahn committed Mar 7, 2024
2 parents 095e491 + 4821804 commit dd0db07
Showing 10 changed files with 215 additions and 98 deletions.
8 changes: 4 additions & 4 deletions README.md
@@ -189,20 +189,20 @@ Alternatively, directly use `src/modalities/__main__.py do_stuff --config_file_p
The `MemMapDataset` requires an index file providing the necessary pointers into the raw data file. The `MemMapDataset` can create the index file lazily, however, it is advised to create it beforehand. This can be done by running

```sh
modalities create_memmap_index <path/to/jsonl/file>
modalities data create_raw_index <path/to/jsonl/file>
```

The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities create_memmap_index --help`.
The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities data create_raw_index --help`.

## Packed Dataset Generator

The `PackedMemMapDatasetContinuous` and `PackedMemMapDatasetMegatron` require a packed data file. To create the data file, you first have to generate a `MemMapDataset` index file as described [above](#memmapdataset-index-generator). Assuming the index and raw data are located in the same directory, you can simply execute the following command:

```sh
modalities create_packed_data <path/to/jsonl/file>
modalities data pack_encoded_data <path/to/jsonl/file>
```

The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities create_packed_data --help`.
The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via `modalities data pack_encoded_data --help`.

### Packed Data Format

4 changes: 2 additions & 2 deletions benchmarks/dataloader/launch_benchmark.sh
@@ -10,10 +10,10 @@ measure_modalities_preparation() {
set -e
test -f $INPUT_DIR
rm -f ${INPUT_DIR/.jsonl/.idx}
modalities create_memmap_index $INPUT_DIR &> /dev/null
modalities data create_raw_index $INPUT_DIR &> /dev/null
echo "finished memmap index creation"
rm -f ${INPUT_DIR/.jsonl/.pbin}
modalities create_packed_data $INPUT_DIR &> /dev/null
modalities data pack_encoded_data $INPUT_DIR &> /dev/null
echo "finished memmap packing"
)
}
8 changes: 4 additions & 4 deletions docs/source/memmap.rst
@@ -14,9 +14,9 @@ The :python:`MemMapDataset` requires an index file providing the necessary point

.. code-block:: bash
modalities create_memmap_index <path/to/jsonl/file>
modalities data create_raw_index <path/to/jsonl/file>
The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities create_memmap_index --help`.
The index will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities data create_raw_index --help`.

Packed Dataset Generator
--------------------------------------------------------------------------------
@@ -25,9 +25,9 @@ The :python:`PackedMemMapDatasetContinuous` and :python:`PackedMemMapDatasetMega

.. code-block:: bash
modalities create_packed_data <path/to/jsonl/file>
modalities data pack_encoded_data <path/to/jsonl/file>
The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities create_packed_data --help`.
The packed data file will be created in the same directory as the raw data file. For further options you may look into the usage documentation via :bash:`modalities data pack_encoded_data --help`.

Packed Data Format
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
8 changes: 4 additions & 4 deletions docs/source/quickstart.rst
@@ -12,20 +12,20 @@ To start a training you need to create memmap dataset out of a jsonl file first,
.. code-block:: bash
# Create memmap dataset from jsonl file.
modalities create_memmap_index <path/to/jsonl/file>
modalities data create_raw_index <path/to/jsonl/file>
# Create packed dataset.
modalities create_packed_data <path/to/jsonl/file>
modalities data pack_encoded_data <path/to/jsonl/file>
For example, using the lorem ipsum example:

.. code-block:: bash
# Create memmap dataset from jsonl file.
modalities create_memmap_index data/lorem_ipsum.jsonl
modalities data create_raw_index data/lorem_ipsum.jsonl
# Create packed dataset.
modalities create_packed_data data/lorem_ipsum.jsonl
modalities data pack_encoded_data data/lorem_ipsum.jsonl
Training
----------------------------------------------------
28 changes: 17 additions & 11 deletions examples/getting_started/README.md
@@ -41,29 +41,29 @@ Firstly, we create the dataset index via
cd modalities/examples/getting_started/

# train split
modalities create_memmap_index --index_path data/mem_map/redpajama_v2_samples_512_train.idx \
modalities data create_raw_index --index_path data/mem_map/redpajama_v2_samples_512_train.idx \
data/raw/redpajama_v2_samples_512_train.jsonl

# test split
modalities create_memmap_index --index_path data/mem_map/redpajama_v2_samples_512_test.idx \
modalities data create_raw_index --index_path data/mem_map/redpajama_v2_samples_512_test.idx \
data/raw/redpajama_v2_samples_512_test.jsonl
```
In this step, we read the JSON file as a binary file, iterate over all characters and build up the sample index (char-wise start and end position for each JSON sample)
as determined by the `\n` character positions. The sample index is stored in the specified `index_path`. Internally, the `create_memmap_index` command
as determined by the `\n` character positions. The sample index is stored in the specified `index_path`. Internally, the `create_raw_index` command
instantiates and calls the [IndexGenerator](https://github.com/Modalities/modalities/blob/main/src/modalities/dataloader/create_index.py#L14).
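
To make the idea concrete, here is a minimal sketch of such a newline-based index. The helper name `build_line_index` is illustrative and not part of modalities; the actual work is done by the `IndexGenerator` linked above.

```python
# Illustrative sketch: collect (start, length) byte positions of each JSON line,
# mirroring the char-wise indexing described above.
from pathlib import Path
from typing import List, Tuple


def build_line_index(jsonl_path: Path) -> List[Tuple[int, int]]:
    index: List[Tuple[int, int]] = []
    with jsonl_path.open("rb") as f:
        raw = f.read()
    start = 0
    for pos, byte in enumerate(raw):
        if byte == ord("\n"):
            if pos > start:  # skip empty lines
                index.append((start, pos - start))
            start = pos + 1
    if start < len(raw):  # final sample without trailing newline
        index.append((start, len(raw) - start))
    return index
```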

After having determined the index, we create the packed dataset as described below by leveraging the tokenizer, jsonl file and the created index.

```sh
# train split
modalities create_packed_data --jq_pattern .raw_content \
modalities data pack_encoded_data --jq_pattern .raw_content \
--index_path data/mem_map/redpajama_v2_samples_512_train.idx \
--dst_path data/mem_map/redpajama_v2_samples_512_train.pbin \
--tokenizer_file tokenizer/tokenizer.json \
data/raw/redpajama_v2_samples_512_train.jsonl

# test split
modalities create_packed_data --jq_pattern .raw_content \
modalities data pack_encoded_data --jq_pattern .raw_content \
--index_path data/mem_map/redpajama_v2_samples_512_test.idx \
--dst_path data/mem_map/redpajama_v2_samples_512_test.pbin \
--tokenizer_file tokenizer/tokenizer.json \
@@ -83,15 +83,21 @@ Technically, packed datasets are defined a self-contained format that stores the
**Packed MemMap File Format**

```
|--8-BYTES-HEADER--|-------------------DATA-SEGMENT-------------------|----INDEX-SEGMENT----|
|--HEADER--|-------------------DATA-SEGMENT-------------------|----INDEX-SEGMENT----|
8 bytes header:
header:
===============
specifies the size of the data segment in bytes. Since the header size is fixed to 8 bytes,
the start and end position of each segment (i.e, header, data, index) is specified. Therefore, the theoretical maximum size of the data segment
is 2^64 bytes = 18,446 peta bytes or 4600e+15 tokens or 4.6 quintillion tokens, given that a token has 4 bytes.
Contains two elements:
* The size of the data segment in bytes, stored as an 8-byte integer. Since this field has a fixed width,
  the start and end position of each segment (i.e., header, data, index) can be derived from it.
  The theoretical maximum size of the data segment is therefore
  2^64 bytes = 18,446 petabytes, or roughly 4.6e+18 (4.6 quintillion) tokens, given that a token takes 4 bytes.
* The size of a single token in the data segment in bytes, stored as a 4-byte integer.
  This value is inferred from the source data of the `.pbin` file
  and depends solely on the vocabulary of the tokenizer used for encoding.
Therefore, the header is always 8+4=12 bytes long.
Data segment:
=============
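To make the header layout above concrete, the following sketch reads the two header fields of a `.pbin` file using only the standard library. The field widths and big-endian byte order mirror the `EmbeddedStreamData` constants introduced in this commit; the helper name is illustrative.

```python
# Illustrative sketch: read the 12-byte .pbin header
# (8 bytes data-segment length + 4 bytes token width, both big-endian).
from pathlib import Path
from typing import Tuple

DATA_SECTION_LENGTH_IN_BYTES = 8
TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4


def read_pbin_header(pbin_path: Path) -> Tuple[int, int]:
    with pbin_path.open("rb") as f:
        data_len = int.from_bytes(f.read(DATA_SECTION_LENGTH_IN_BYTES), byteorder="big")
        token_size = int.from_bytes(f.read(TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES), byteorder="big")
    # data_len bytes of tokens follow the header, i.e. data_len // token_size tokens in total
    return data_len, token_size
```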
56 changes: 49 additions & 7 deletions src/modalities/__main__.py
@@ -14,7 +14,7 @@
from modalities.config.component_factory import ComponentFactory
from modalities.config.config import ComponentsModel, ProcessGroupBackendType, TokenizerTypes, load_app_config_dict
from modalities.dataloader.create_index import IndexGenerator
from modalities.dataloader.create_packed_data import PackedDataGenerator
from modalities.dataloader.create_packed_data import EmbeddedStreamData, PackedDataGenerator, join_embedded_stream_data
from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader
from modalities.evaluator import Evaluator
from modalities.gym import Gym
@@ -72,15 +72,29 @@ def entry_point_generate_text(model_path, config_path, tokenizer_type, tokenizer
generate_text_main(model_path, config_path, tokenizer, max_new_tokens, chat)


@main.command(name="create_memmap_index")
@main.group(name="data")
def data():
"""
Collection of utilities to preprocess, analyse and modify training data.
"""
pass


@data.command(name="create_raw_index")
@click.argument("src_path", type=Path)
@click.option(
"--index_path",
type=Path,
default=None,
help="output path for index. will use parent directory of src_path if none.",
)
def entry_point_create_memmap_index(src_path, index_path):
def entry_point_data_create_raw_index(src_path, index_path):
"""
Utility for indexing the content of a large jsonl-file.
The index allows the file to be processed line by line later on without loading it entirely.
This step is necessary ahead of further processing such as tokenization.
It has to be done only once per jsonl-file and therefore allows different tokenizations without re-indexing.
"""
index_path = LargeFileLinesReader.default_index_path(src_path, index_path)
if index_path.exists():
raise ValueError("index already exists. delete it or specify different output folder.")
@@ -91,7 +91,7 @@ def entry_point_create_memmap_index(src_path, index_path):
generator.create_index(index_path)


@main.command(name="create_packed_data")
@data.command(name="pack_encoded_data")
@click.argument("src_path", type=Path)
@click.option(
"--dst_path",
@@ -133,9 +147,14 @@ def entry_point_create_memmap_index(src_path, index_path):
default=os.cpu_count(),
help="Specify the number of tokenization workers. Default is the number of available CPUs.",
)
def entry_point_create_packed_data(
src_path, dst_path, index_path, tokenizer_type, tokenizer_file, jq_pattern, num_cpus
):
def entry_point_pack_encoded_data(src_path, dst_path, index_path, tokenizer_type, tokenizer_file, jq_pattern, num_cpus):
"""
Utility to encode an indexed, large jsonl-file
(see also `create_raw_index` for more information).
Returns a .pbin-file, which can be fed into a training process directly
and no longer requires the original jsonl-file or the respective index file.
"""
# TODO: if we want to use alternative entrypoints together with the ResolverRegistry,
# we can currently not rely on the existing class resolver.
# This is based on its connection to the overall `AppConfig`.
@@ -153,6 +172,29 @@ def entry_point_create_packed_data(
generator.run(dst_path)


@data.command(name="merge_packed_data")
@click.argument("src_paths", type=click.types.Path(exists=True, path_type=Path), nargs=-1, required=True)
@click.argument("target_path", type=click.types.Path(file_okay=False, dir_okay=False, path_type=Path))
def entry_point_merge_packed_data(src_paths, target_path):
"""
Utility for merging different pbin-files into one.
This is especially useful if different datasets were prepared at different points in time,
or if a single encoding run takes so long that the overall process was split into chunks.
It is important that the same tokenizer was used for all chunks.
Specify an arbitrary number of pbin-files and/or directories containing such files as input.
"""
input_files = []
for p in src_paths:
p: Path
if p.is_dir():
input_files.extend(p.glob("**/*.pbin"))
else:
input_files.append(p)
embedded_datasets = list(map(EmbeddedStreamData, input_files))
join_embedded_stream_data(embedded_datasets, target_path)


class Main:
def __init__(self, config_dict: Dict, config_path: Path) -> None:
self.config_dict = config_dict
101 changes: 88 additions & 13 deletions src/modalities/dataloader/create_packed_data.py
@@ -1,30 +1,27 @@
import logging
import math
import multiprocessing
import os
import pickle
import warnings
from pathlib import Path
from typing import Callable, List, Tuple
from typing import Callable, Iterator, List, Tuple

import jq
import numpy as np
from tqdm import tqdm
from transformers import PreTrainedTokenizer

from modalities.dataloader.large_file_lines_reader import LargeFileLinesReader

logger = logging.getLogger(__name__)


class EmptySampleError(RuntimeError):
pass


class PackedDataGenerator:
# amount of bytes to represent number of all tokens in dataset.
# If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation.
# Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides
DATA_SECTION_LENGTH_IN_BYTES = 8
TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4
HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES

def __init__(
self,
src_path: Path,
@@ -126,9 +123,13 @@ def writer():
with dst_path.open("wb") as f:
# allocate first self.header_size_in_bytes bytes for header (encodes length of data section)
# not possible to prepend header after determining size of data section
f.write((0).to_bytes(self.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big"))
f.write(self._token_size_in_bytes.to_bytes(self.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="big"))
curr_offset = self.HEADER_SIZE_IN_BYTES
f.write((0).to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big"))
f.write(
self._token_size_in_bytes.to_bytes(
EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="big"
)
)
curr_offset = EmbeddedStreamData.HEADER_SIZE_IN_BYTES

# write data section (tokens)
for tokens_as_bytes in tqdm(
@@ -160,9 +161,9 @@ def _process_thread(self, process_id: int):

def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: List[Tuple[int, int]]):
start_of_index_in_bytes = index_list[-1][0] + index_list[-1][1]
length_of_byte_encoded_data_section = start_of_index_in_bytes - self.HEADER_SIZE_IN_BYTES
length_of_byte_encoded_data_section = start_of_index_in_bytes - EmbeddedStreamData.HEADER_SIZE_IN_BYTES
data_section_length_in_bytes = length_of_byte_encoded_data_section.to_bytes(
self.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big"
EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big"
)
with dst_path.open("rb+") as fout:
fout.seek(0)
@@ -176,3 +177,77 @@ def _process_line(self, line: str) -> bytes:
if len(tokens) == 0:
raise EmptySampleError("Received empty sample...")
return b"".join(map(self._encoded_token_to_bytes, tokens)) + self._encoded_eos_token_as_bytes


class EmbeddedStreamData:
# amount of bytes to represent number of all tokens in dataset.
# If the amount exceeds 2^(8*`header_size_in_bytes`), this requires adaptation.
# Decided to keep this constant, since a size of 8 bytes requires more data than the internet currently provides
DATA_SECTION_LENGTH_IN_BYTES = 8
TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES = 4
HEADER_SIZE_IN_BYTES = DATA_SECTION_LENGTH_IN_BYTES + TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES

def __init__(self, data_path: Path):
self._data_path = data_path
if not self._data_path.is_file():
raise FileNotFoundError(
f"Packed Data was not found at {self._data_path}."
f"Create on in advance by using `modalities data pack_encoded_data`."
)

with self._data_path.open("rb") as f:
# get number of bytes in data section
data_section_length_in_bytes = f.read(self.DATA_SECTION_LENGTH_IN_BYTES)
self.data_len = int.from_bytes(data_section_length_in_bytes, byteorder="big")

# get number of bytes for encoding a single token
f.seek(self.DATA_SECTION_LENGTH_IN_BYTES)
token_size_as_bytes = f.read(self.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES)
self.token_size_in_bytes = int.from_bytes(token_size_as_bytes, byteorder="big", signed=False)

# get index
f.seek(self.HEADER_SIZE_IN_BYTES + self.data_len)
pkl_encoded_index = f.read()
self.index_base = pickle.loads(pkl_encoded_index)

# initialize memmapped data section
self.data = np.memmap(self._data_path, mode="r", offset=self.HEADER_SIZE_IN_BYTES, shape=(self.data_len,))


def join_embedded_stream_data(stream_data: List[EmbeddedStreamData], target_file: Path, chunk_size: int = 2048):
if target_file.exists():
raise FileExistsError(f'Target File at "{target_file}" exists!')
data_len = sum(d.data_len for d in stream_data)
assert len({d.token_size_in_bytes for d in stream_data}) == 1, (
"Found different token representation sizes. This could indicate the usage of different tokenizers. "
"Not supported!"
)
token_size_in_bytes = stream_data[0].token_size_in_bytes

num_data_chunks = sum(math.ceil(d.data_len / chunk_size) for d in stream_data)
data_stream_generator = (d.data[i : i + chunk_size] for d in stream_data for i in range(0, d.data_len, chunk_size))

num_entries = sum(len(d.index_base) for d in stream_data)

def index_stream_generator() -> Iterator[Tuple[int, int]]:
curr_offset = 0
for embedded_stream_data in stream_data:
for entry_offset, segment_length in embedded_stream_data.index_base:
yield entry_offset + curr_offset, segment_length
curr_offset += embedded_stream_data.data_len
curr_offset -= embedded_stream_data.HEADER_SIZE_IN_BYTES

with target_file.open("wb") as fout:
fout.write(data_len.to_bytes(EmbeddedStreamData.DATA_SECTION_LENGTH_IN_BYTES, byteorder="big"))
fout.write(
token_size_in_bytes.to_bytes(EmbeddedStreamData.TOKEN_SIZE_DESCRIPTOR_LENGTH_IN_BYTES, byteorder="big")
)
for data_chunk in tqdm(data_stream_generator, total=num_data_chunks, desc="Writing Data Chunks..."):
fout.write(data_chunk)

joint_index = [entry for entry in tqdm(index_stream_generator(), total=num_entries, desc="Concatenating Index")]
pickled_index = pickle.dumps(joint_index)
pickled_index_as_chunks = (pickled_index[i : i + chunk_size] for i in range(0, len(pickled_index), chunk_size))
num_index_chunks = math.ceil(len(pickled_index) / chunk_size)
for index_chunk in tqdm(pickled_index_as_chunks, total=num_index_chunks, desc="Writing Index Chunks..."):
fout.write(index_chunk)
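
As a usage note, the new `modalities data merge_packed_data` entry point shown earlier is a thin wrapper around `EmbeddedStreamData` and `join_embedded_stream_data`. Below is a minimal programmatic sketch with placeholder file names, assuming all chunks were encoded with the same tokenizer.

```python
# Sketch: merge two .pbin chunks into one file; the paths are placeholders.
from pathlib import Path

from modalities.dataloader.create_packed_data import EmbeddedStreamData, join_embedded_stream_data

chunks = [EmbeddedStreamData(Path(p)) for p in ("train_part_0.pbin", "train_part_1.pbin")]
join_embedded_stream_data(chunks, Path("train_merged.pbin"))  # raises FileExistsError if the target exists
```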