@@ -723,30 +723,32 @@ def pretraining_is_using_legacy_granite_chat_template(ds: Dataset) -> bool:
723723 return False
724724
725725
def ensure_dataset_is_compatible_with_legacy_format(
    batch: t.Dict[str, t.List[t.Any]],
) -> t.Dict[str, t.List[t.Any]]:
    """
    Given a batch of samples using the legacy pre-training format, unroll the samples into ones with
    the original messages contents.

    This is written as a batched mapping function: every column in `batch` is a list holding
    one entry per sample, and the returned columns have the same number of entries.

    Args:
        batch: Column-oriented batch of samples. It must contain a "messages" column in which
            each entry is the list of message dicts for one sample. Every message needs a
            "role" key; messages whose role is "pretraining" also need a "content" key.
            An optional "unmask" column carries one flag per sample; when it is absent,
            every flag defaults to False.

    Returns:
        A dict with two columns:
            - "messages": for each sample, the messages with every "pretraining" message
              replaced (in place, preserving order) by whatever
              `extract_messages_from_pretraining_text` returns for its content.
            - "unmask": for each sample, True if it contained at least one "pretraining"
              message, otherwise the truthiness of the incoming flag.
    """
    all_messages = batch["messages"]

    # Only build the default flags when the column is actually missing.
    incoming_flags = batch.get("unmask")
    if incoming_flags is None:
        incoming_flags = [False] * len(all_messages)

    processed_messages = []
    unmask_flags = []

    for messages, unmask_flag in zip(all_messages, incoming_flags):
        new_messages = []
        # Normalize to a real bool so that a null flag (e.g. a sample that lacked the
        # field) is emitted as False rather than None.
        unmask = bool(unmask_flag)

        for msg in messages:
            if msg["role"] != "pretraining":
                new_messages.append(msg)
                continue

            # A legacy pre-training message packs the original conversation into its
            # content; expand it back into individual messages.
            new_messages.extend(extract_messages_from_pretraining_text(msg["content"]))
            # if any pretraining message is found, the sample must be unmasked
            unmask = True

        processed_messages.append(new_messages)
        unmask_flags.append(unmask)

    return {
        "messages": processed_messages,
        "unmask": unmask_flags,
    }
750752
751753
752754def filter_samples_by_length (
@@ -876,6 +878,8 @@ def load_and_validate_dataset(data_path: str, num_procs: int) -> Dataset:
876878
877879 return data .map (
878880 ensure_dataset_is_compatible_with_legacy_format ,
881+ batched = True ,
882+ batch_size = 1000 ,
879883 num_proc = num_procs ,
880884 desc = "Ensuring dataset is compatible with legacy format." ,
881885 )
0 commit comments