Skip to content

Commit dd45832

Browse files
committed
Use batched mapping (batched=True) for the data.map call in load_and_validate_dataset in data_process.py, converting ensure_dataset_is_compatible_with_legacy_format to operate on batches of samples.
1 parent 5d510fb commit dd45832

File tree

1 file changed

+24
-20
lines changed

1 file changed

+24
-20
lines changed

src/instructlab/training/data_process.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -723,30 +723,32 @@ def pretraining_is_using_legacy_granite_chat_template(ds: Dataset) -> bool:
723723
return False
724724

725725

726-
def ensure_dataset_is_compatible_with_legacy_format(
    batch: t.Dict[str, t.List[t.Any]],
) -> t.Dict[str, t.List[t.Any]]:
    """
    Unroll a batch of samples that use the legacy pre-training format into
    samples containing the original message contents.

    Messages whose role is "pretraining" embed the real conversation inside
    their content; those are expanded via
    ``extract_messages_from_pretraining_text`` and the sample's ``unmask``
    flag is forced to ``True``. All other messages pass through unchanged.

    Args:
        batch: Column-oriented batch as produced by ``datasets.Dataset.map``
            with ``batched=True``. Must contain a "messages" column; the
            "unmask" column is optional and defaults to ``False`` per sample.

    Returns:
        A batch dict with the processed "messages" column and the per-sample
        "unmask" flags.
    """
    all_messages = batch["messages"]
    # dict.get evaluates its default eagerly, so passing `[False] * len(...)`
    # as the default would allocate a throwaway list on every call even when
    # the column exists. Build the default only when actually needed.
    unmask_column = batch.get("unmask")
    if unmask_column is None:
        unmask_column = [False] * len(all_messages)

    processed_messages: t.List[t.List[t.Dict[str, t.Any]]] = []
    unmask_flags: t.List[bool] = []

    for messages, unmask in zip(all_messages, unmask_column):
        new_messages = []
        for msg in messages:
            if msg["role"] != "pretraining":
                new_messages.append(msg)
            else:
                # Legacy pretraining samples wrap the real messages inside
                # the message content; unroll them and enable unmasking.
                new_messages.extend(
                    extract_messages_from_pretraining_text(msg["content"])
                )
                unmask = True  # any pretraining message forces unmasking
        processed_messages.append(new_messages)
        unmask_flags.append(unmask)

    return {
        "messages": processed_messages,
        "unmask": unmask_flags,
    }
750752

751753

752754
def filter_samples_by_length(
@@ -876,6 +878,8 @@ def load_and_validate_dataset(data_path: str, num_procs: int) -> Dataset:
876878

877879
return data.map(
878880
ensure_dataset_is_compatible_with_legacy_format,
881+
batched=True,
882+
batch_size=1000,
879883
num_proc=num_procs,
880884
desc="Ensuring dataset is compatible with legacy format.",
881885
)

0 commit comments

Comments
 (0)