diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 1902b016a4..4d93faa6df 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -231,22 +231,28 @@ def _get_sequence_parallel_dataset(
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
     if data_args.shuffle_for_sequence_parallel:
         dataset = dataset.shuffle(seed=training_args.seed)
-    kwargs = dict(
-        num_proc=data_args.preprocessing_num_workers,
-        load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
-        desc="Running padding split on dataset",
-    )
+    if not data_args.streaming:
+        kwargs = dict(
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
+            desc="Running padding split on dataset",
+        )
+    else:
+        kwargs = dict()
     pad_sequence_func = get_sequence_parallel_preprocess(
         data_args=data_args, model_args=model_args, stage="pad", tokenizer=tokenizer
     )
     padded_dataset = dataset.map(
         pad_sequence_func, batched=True, batch_size=data_args.preprocessing_batch_size, **kwargs
     )
-    kwargs = dict(
-        num_proc=data_args.preprocessing_num_workers,
-        load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
-        desc="Running sequence parallel split on dataset",
-    )
+    if not data_args.streaming:
+        kwargs = dict(
+            num_proc=data_args.preprocessing_num_workers,
+            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
+            desc="Running sequence parallel split on dataset",
+        )
+    else:
+        kwargs = dict()
     sp_dataset_func = get_sequence_parallel_preprocess(
         data_args=data_args, model_args=model_args, stage="split", tokenizer=tokenizer
     )
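Background for the gating above (a minimal sketch, not part of the patch): when streaming is enabled the dataset is a Hugging Face datasets.IterableDataset, and its map() does not accept num_proc, load_from_cache_file, or desc, so passing them is rejected as unexpected keyword arguments. The helper below is illustrative only; the name build_map_kwargs and its parameters are assumptions, not code from this repository.

    from typing import Any, Dict

    def build_map_kwargs(streaming: bool, num_workers: int, overwrite_cache: bool,
                         local_process_index: int, desc: str) -> Dict[str, Any]:
        """Return map() kwargs valid for the dataset type in use."""
        if streaming:
            # IterableDataset.map() supports only a small subset of arguments,
            # so no caching/progress kwargs are forwarded in streaming mode.
            return {}
        return {
            "num_proc": num_workers,
            # Reuse the cache unless it is explicitly overwritten on the main process.
            "load_from_cache_file": (not overwrite_cache) or (local_process_index != 0),
            "desc": desc,
        }

Both map() calls in the hunk reuse this pattern, so the padding and sequence-parallel split steps keep their multiprocessing, caching, and progress description on map-style datasets while no longer passing unsupported arguments in streaming mode.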