Commit

Merge pull request #107 from Modalities/data_packing_performance_improvement

Improving throughput for PackedDataGenerator
mali-git committed Apr 23, 2024
2 parents 27893fd + 73ac1d9 commit f903c6e
Showing 4 changed files with 93 additions and 25 deletions.
18 changes: 18 additions & 0 deletions config_files/data_preparation/packed_cc_en_2048.yaml
@@ -0,0 +1,18 @@
+settings:
+  src_path: /workspaces/modalities/data/cc_en/raw/train.jsonl
+  dst_path: /workspaces/modalities/data/cc_en/processed/train.pbin
+  index_path: /workspaces/modalities/data/cc_en/processed/train.idx
+  jq_pattern: .text
+  num_cpus: ${node_env:num_cpus}
+  eod_token: <eod>
+  processing_batch_size: 1000
+  raw_samples_queue_size: 300
+  processed_samples_queue_size: 300
+
+tokenizer:
+  component_key: tokenizer
+  variant_key: pretrained_sp_tokenizer
+  config:
+    tokenizer_model_file: /workspaces/modalities/data/tokenizer/sp_bpe_en/bpe_tokenizer.model
+    padding: false
+    max_length: 2048
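
For reference, the settings block above carries exactly the fields that the packing entry point reads (see the __main__.py and config.py changes below). The following minimal sketch parses that block with plain PyYAML; the ${node_env:num_cpus} resolver belongs to the modalities config machinery, so substituting os.cpu_count() here is only an assumption made to keep the snippet self-contained.

# Hedged sketch: parse the settings block of the config above with plain PyYAML.
# The real pipeline builds components via the modalities config system; the
# os.cpu_count() stand-in for ${node_env:num_cpus} is an assumption for this snippet only.
import os

import yaml

RAW_CONFIG = """
settings:
  src_path: /workspaces/modalities/data/cc_en/raw/train.jsonl
  dst_path: /workspaces/modalities/data/cc_en/processed/train.pbin
  index_path: /workspaces/modalities/data/cc_en/processed/train.idx
  jq_pattern: .text
  eod_token: <eod>
  processing_batch_size: 1000
  raw_samples_queue_size: 300
  processed_samples_queue_size: 300
"""

settings = yaml.safe_load(RAW_CONFIG)["settings"]
settings["num_cpus"] = os.cpu_count()  # stand-in for ${node_env:num_cpus}
print(settings)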
12 changes: 7 additions & 5 deletions src/modalities/__main__.py
@@ -14,7 +14,7 @@
 from modalities.batch import EvaluationResultBatch
 from modalities.config.component_factory import ComponentFactory
 from modalities.config.config import (
-    PackedDatasetComponentsModel,
+    PackedDatasetComponentsInstantiationModel,
     ProcessGroupBackendType,
     TokenizerTypes,
     TrainingComponentsInstantiationModel,
@@ -134,18 +134,20 @@ def entry_point_pack_encoded_data(config_path: FilePath):
     config = load_app_config_dict(config_path)
     registry = Registry(COMPONENTS)
     component_factory = ComponentFactory(registry=registry)
-    components: PackedDatasetComponentsModel = component_factory.build_components(
-        config_dict=config, components_model_type=PackedDatasetComponentsModel
+    components: PackedDatasetComponentsInstantiationModel = component_factory.build_components(
+        config_dict=config, components_model_type=PackedDatasetComponentsInstantiationModel
     )
 
-    tokenizer = components.tokenizer
     generator = PackedDataGenerator(
         components.settings.src_path,
         index_path=components.settings.index_path,
-        tokenizer=tokenizer,
+        tokenizer=components.tokenizer,
         eod_token=components.settings.eod_token,
         jq_pattern=components.settings.jq_pattern,
         number_of_processes=components.settings.num_cpus,
+        processing_batch_size=components.settings.processing_batch_size,
+        raw_samples_queue_size=components.settings.raw_samples_queue_size,
+        processed_samples_queue_size=components.settings.processed_samples_queue_size,
     )
     generator.run(components.settings.dst_path)
 
5 changes: 4 additions & 1 deletion src/modalities/config/config.py
@@ -374,6 +374,9 @@ class PackedDatasetSettings(BaseModel):
     jq_pattern: str
     num_cpus: Annotated[int, Field(strict=True, ge=1)] = os.cpu_count()
     eod_token: str
+    processing_batch_size: Annotated[int, Field(strict=True, ge=1)]
+    raw_samples_queue_size: Annotated[int, Field(strict=True, ge=1)]
+    processed_samples_queue_size: Annotated[int, Field(strict=True, ge=1)]
 
 
 class TrainingSettings(BaseModel):
@@ -422,7 +425,7 @@ class TrainingComponentsInstantiationModel(BaseModel):
     settings: TrainingSettings
 
 
-class PackedDatasetComponentsModel(BaseModel):
+class PackedDatasetComponentsInstantiationModel(BaseModel):
     tokenizer: PydanticTokenizerIFType
     settings: PackedDatasetSettings
 
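
The three new fields reuse the Annotated[int, Field(strict=True, ge=1)] pattern already present on num_cpus, so a zero or negative batch or queue size is rejected when the config is built rather than surfacing later as a mis-sized multiprocessing queue. A small standalone sketch of that behaviour follows; SettingsSketch is a hypothetical stand-in for PackedDatasetSettings, and pydantic v2 is assumed.

# Standalone sketch of the validation added above; SettingsSketch is a
# hypothetical stand-in for PackedDatasetSettings (pydantic v2 assumed).
from typing import Annotated

from pydantic import BaseModel, Field, ValidationError


class SettingsSketch(BaseModel):
    processing_batch_size: Annotated[int, Field(strict=True, ge=1)]
    raw_samples_queue_size: Annotated[int, Field(strict=True, ge=1)]
    processed_samples_queue_size: Annotated[int, Field(strict=True, ge=1)]


print(SettingsSketch(processing_batch_size=1000, raw_samples_queue_size=300, processed_samples_queue_size=300))

try:
    SettingsSketch(processing_batch_size=0, raw_samples_queue_size=300, processed_samples_queue_size=300)
except ValidationError as err:
    print(err)  # ge=1 rejects a batch size of 0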
83 changes: 64 additions & 19 deletions src/modalities/dataloader/create_packed_data.py
@@ -30,6 +30,9 @@ def __init__(
         eod_token: str,
         number_of_processes: int,
         jq_pattern: str,
+        processing_batch_size: int,
+        raw_samples_queue_size: int,
+        processed_samples_queue_size: int,
         index_path: Optional[FilePath] = None,
     ):
         """
@@ -39,6 +42,8 @@
             and length of samples given in `src_path`.
             If not defined, an index file next to `src_path` is picked,
             by replacing its suffix with ".idx".
+        :processing_batch_size: The size of the batches that the workers process
+            (has nothing to do with batch size during training!).
         :param tokenizer: PretrainedTokenizer object, which is used to pre-tokenize the provided data in `src_path`.
             Tokenization is necessary to work on final lengths of token sequences.
         :param jq_pattern: jq-pattern applied on every jsonl-entry. Results are afterwards tokenized and packed
@@ -53,8 +58,10 @@ def __init__(
         self._number_of_processes = number_of_processes
         self._reader = LargeFileLinesReader(src_path, index_path=index_path)
         self._total_num_of_tokens = 0
-        self._tokens_write_queue = multiprocessing.Queue()
+        self._raw_samples_queue = multiprocessing.Queue(maxsize=raw_samples_queue_size)
+        self.processed_samples_queue = multiprocessing.Queue(maxsize=processed_samples_queue_size)
         self._exception_buffer = []
+        self.processing_batch_size = processing_batch_size
 
     @staticmethod
     def _get_required_num_of_bytes_to_repr(int_to_get_repr: int) -> int:
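
Replacing the unbounded _tokens_write_queue with two bounded queues gives the pipeline backpressure: once a queue holds maxsize batches, further put() calls block until a consumer drains it, keeping memory roughly proportional to queue_size * processing_batch_size documents. A tiny illustration of that blocking behaviour follows (arbitrary numbers, not modalities code).

# Minimal illustration of the backpressure a bounded multiprocessing.Queue provides.
import multiprocessing
import queue

q = multiprocessing.Queue(maxsize=2)
q.put("batch-1")
q.put("batch-2")
try:
    q.put("batch-3", timeout=0.5)  # queue is full, so this blocks and then raises queue.Full
except queue.Full:
    print("producer throttled until a consumer calls get()")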
@@ -93,6 +100,9 @@ def run(self, dst_path: Optional[Path] = None):
             raise self._exception_buffer[0]
 
     def _launch_parallelized_workers(self, dst_path: Path):
+        reader = multiprocessing.Process(target=self._reader_thread())
+        reader.start()
+
         writer = multiprocessing.Process(target=self._writer_thread(dst_path))
         writer.start()
         processor_threads = [
@@ -106,16 +116,16 @@ def _launch_parallelized_workers(self, dst_path: Path):
         writer.join()
 
     def _stop_processing(self):
-        self._tokens_write_queue.put(None)
+        self.processed_samples_queue.put(None)
 
     def _generator_for_tokens_to_get_written(self):
         while True:
             if self._check_for_parallel_errors():
                 return
-            tokens = self._tokens_write_queue.get()
-            if tokens is None:
+            batch = self.processed_samples_queue.get()
+            if batch is None:
                 break
-            yield tokens
+            yield batch
 
     def _check_for_parallel_errors(self) -> bool:
         return bool(self._exception_buffer)
@@ -135,32 +145,67 @@ def writer():
                 curr_offset = EmbeddedStreamData.HEADER_SIZE_IN_BYTES
 
                 # write data section (tokens)
-                for tokens_as_bytes in tqdm(
-                    self._generator_for_tokens_to_get_written(), desc="Processed Samples", total=len(self._reader)
-                ):
-                    f.write(tokens_as_bytes)
-                    segment_length = len(tokens_as_bytes)
-                    index_list.append((curr_offset, segment_length))
-                    curr_offset += segment_length
-
+                pbar = tqdm(total=len(self._reader), desc="Processed batches")
+                for batch in self._generator_for_tokens_to_get_written():
+                    # write the tokens for each document
+                    for tokens_as_bytes in batch:
+                        f.write(tokens_as_bytes)
+                        segment_length = len(tokens_as_bytes)
+                        index_list.append((curr_offset, segment_length))
+                        curr_offset += segment_length
+                    pbar.update(len(batch))
                 # write index
                 f.write(pickle.dumps(index_list))
 
             self._update_data_length_in_pre_allocated_header(dst_path, index_list)
 
         return writer
 
+    def _reader_thread(self) -> Callable:
+        def reader():
+            batch = []
+            for line_id, line in tqdm(enumerate(self._reader), desc="Reading jsonl", disable=True):
+                # line = self._reader[line_id]
+                batch.append((line_id, line))
+                if len(batch) % self.processing_batch_size == 0:
+                    self._raw_samples_queue.put(batch)
+                    batch = []
+
+            # add the remaining samples
+            if len(batch) > 0:
+                self._raw_samples_queue.put(batch)
+
+            for _ in range(self._number_of_processes):
+                self._raw_samples_queue.put(None)
+
+        return reader
+
     def _process_thread(self, process_id: int):
         if self._check_for_parallel_errors():
             return
-        for idx in range(process_id, len(self._reader), self._number_of_processes):
-            line = self._reader[idx]
+
+        while True:
+            if self._check_for_parallel_errors():
+                return
+            batch = self._raw_samples_queue.get()
+            if batch is None:
+                break
+
             try:
-                self._tokens_write_queue.put(self._process_line(line))
+                batch_processed = []
+                for line_id, line in batch:
+                    processed_line = self._process_line(line, process_id)
+                    batch_processed.append(processed_line)
+                self.processed_samples_queue.put(batch_processed)
             except EmptySampleError:
-                warnings.warn(f"Encountered empty sample in line {idx} of file {self.src_path}")
+                warnings.warn(
+                    f"Encountered empty sample in line {line_id} of file {self.src_path} within process {process_id}"
+                )
             except Exception as exception:
-                warnings.warn(f"could not process line of number {idx}. Raised the following error: {exception=}")
+                warnings.warn(
+                    f"Could not process line of number {line_id} within process {process_id}. "
+                    f"Raised the following error: {exception=}"
+                )
 
     def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list: List[Tuple[int, int]]):
         start_of_index_in_bytes = index_list[-1][0] + index_list[-1][1]
Expand All @@ -172,7 +217,7 @@ def _update_data_length_in_pre_allocated_header(self, dst_path: Path, index_list
fout.seek(0)
fout.write(data_section_length_in_bytes)

def _process_line(self, line: str) -> bytes:
def _process_line(self, line: str, process_id: int) -> bytes:
jq_retrieved_text = self.jq_filter.input_text(line).first()
if jq_retrieved_text is None:
raise ValueError(f"jq was not able to find anything using the expression: {self.jq_filter}")
Expand Down
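
Taken together, the changes above restructure PackedDataGenerator into a three-stage pipeline: one reader process batches raw jsonl lines into a bounded queue, num_cpus worker processes tokenize each batch via _process_line, and a single writer process drains the processed-samples queue and writes the packed output file. The reader posts one None sentinel per worker, and the writer is stopped with a single final sentinel, mirroring _reader_thread and _stop_processing. The following self-contained sketch mirrors that shape with a toy "tokenizer" (plain UTF-8 encoding) and made-up names; it illustrates the pattern, it is not the modalities implementation.

# Simplified sketch of the reader -> workers -> writer pattern introduced above.
# Names, the toy tokenizer, and the in-memory "write" step are illustrative only.
import multiprocessing


def reader(lines, raw_q, num_workers, batch_size):
    batch = []
    for line_id, line in enumerate(lines):
        batch.append((line_id, line))
        if len(batch) == batch_size:
            raw_q.put(batch)
            batch = []
    if batch:
        raw_q.put(batch)  # flush the remaining, smaller batch
    for _ in range(num_workers):
        raw_q.put(None)  # one sentinel per worker


def worker(raw_q, processed_q):
    while True:
        batch = raw_q.get()
        if batch is None:
            break
        # stand-in for tokenization: encode every document of the batch to bytes
        processed_q.put([line.encode("utf-8") for _, line in batch])


def writer(processed_q):
    written = 0
    while True:
        batch = processed_q.get()
        if batch is None:  # single sentinel, cf. _stop_processing()
            break
        written += len(batch)
    print(f"wrote {written} documents")


if __name__ == "__main__":
    lines = [f"document {i}" for i in range(10)]
    raw_q = multiprocessing.Queue(maxsize=4)        # cf. raw_samples_queue_size
    processed_q = multiprocessing.Queue(maxsize=4)  # cf. processed_samples_queue_size
    num_workers = 2

    reader_p = multiprocessing.Process(target=reader, args=(lines, raw_q, num_workers, 3))
    worker_ps = [multiprocessing.Process(target=worker, args=(raw_q, processed_q)) for _ in range(num_workers)]
    writer_p = multiprocessing.Process(target=writer, args=(processed_q,))

    writer_p.start()
    reader_p.start()
    for p in worker_ps:
        p.start()

    reader_p.join()
    for p in worker_ps:
        p.join()
    processed_q.put(None)  # stop the writer once all workers are done
    writer_p.join()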
