@@ -589,7 +589,7 @@ def unmask_messages(
     )


-def unmask_sample(
+def unmask_sample_single(
     sample: t.Dict[str, t.Any], tokenizer: PreTrainedTokenizer
 ) -> ProcessedMessagesData:
     """
@@ -618,6 +618,25 @@ def unmask_sample(
     return unmask_messages(sample["messages"], tokenizer, unmask_roles)


+def unmask_sample(
+    batch: t.Dict[str, t.List[t.Any]], tokenizer: PreTrainedTokenizer
+) -> t.Dict[str, t.List[t.Any]]:
+    input_ids_list = []
+    labels_list = []
+
+    for i in range(len(batch["messages"])):
+        sample = {key: batch[key][i] for key in batch}
+        result = unmask_sample_single(sample, tokenizer)
+
+        input_ids_list.append(result["input_ids"])
+        labels_list.append(result["labels"])
+
+    return {
+        "input_ids": input_ids_list,
+        "labels": labels_list,
+    }
+
+
 def extract_messages_from_pretraining_text(text: str) -> t.List[Message]:
     """
     Given a message from a pretraining message that was formatted using either the generic
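Note on the wrapper above: with batched=True, datasets hands the mapped function a dict of column lists rather than a single row, which is why the new unmask_sample re-slices the batch into per-row dicts before delegating to unmask_sample_single. A minimal sketch of that columnar-to-row conversion, using toy data not taken from the PR:

# Toy illustration of the columnar batch layout that datasets passes when
# batched=True; the row reconstruction mirrors the dict comprehension above.
batch = {
    "messages": [["m0"], ["m1"], ["m2"]],
    "id": [10, 11, 12],
}

rows = [{key: batch[key][i] for key in batch} for i in range(len(batch["messages"]))]
assert rows[1] == {"messages": ["m1"], "id": 11}
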
@@ -925,6 +944,8 @@ def process_samples(
     # Process the dataset
     processed_data = data.map(
         process_sample_fn,
+        batched=True,
+        batch_size=1000,
         num_proc=num_cpu_procs,
         desc="Converting samples into input_ids and labels...",
         load_from_cache_file=False,
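For context, here is a rough end-to-end sketch of the batched mapping contract. The dataset contents, the stand-in fake_unmask_sample, and the remove_columns argument are illustrative assumptions; only batched=True and batch_size=1000 come from the diff, and the real pipeline binds a tokenizer into process_sample_fn before calling data.map.

from datasets import Dataset


def fake_unmask_sample(batch):
    # Stand-in with the same contract as the new unmask_sample:
    # columns in (one list per key), columns out (one entry per input row).
    input_ids = [[len(m["content"]) for m in msgs] for msgs in batch["messages"]]
    return {"input_ids": input_ids, "labels": input_ids}


data = Dataset.from_dict(
    {"messages": [[{"role": "user", "content": "hi"}], [{"role": "user", "content": "hello"}]]}
)

processed = data.map(
    fake_unmask_sample,
    batched=True,      # function receives a dict of column lists
    batch_size=1000,   # up to 1000 rows per call, as in the diff
    remove_columns=data.column_names,  # keep only the newly produced columns
    desc="Converting samples into input_ids and labels...",
)
print(processed[0])  # {'input_ids': [2], 'labels': [2]}

Batching like this reduces per-row Python call and serialization overhead, which is presumably the motivation for pairing the new wrapper with batched=True in the map call.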