386 changes: 386 additions & 0 deletions data_juicer/core/data/load_strategy.py

Large diffs are not rendered by default.

126 changes: 89 additions & 37 deletions data_juicer/core/ray_exporter.py
@@ -1,3 +1,4 @@
import copy
import os
from functools import partial

@@ -6,6 +7,7 @@
from data_juicer.utils.constant import Fields, HashKeys
from data_juicer.utils.file_utils import Sizes, byte_size_to_size_str
from data_juicer.utils.model_utils import filter_arguments
from data_juicer.utils.s3_utils import create_filesystem_from_args
from data_juicer.utils.webdataset_utils import reconstruct_custom_webdataset_format


@@ -22,6 +24,7 @@ class RayExporter:
"tfrecords",
"webdataset",
"lance",
"iceberg",
# 'images',
# 'numpy',
}
@@ -51,37 +54,27 @@ def __init__(
self.export_shard_size = export_shard_size
self.keep_stats_in_res_ds = keep_stats_in_res_ds
self.keep_hashes_in_res_ds = keep_hashes_in_res_ds
self.export_format = self._get_export_format(export_path) if export_type is None else export_type

if export_type:
self.export_format = export_type
elif export_path:
self.export_format = self._get_export_format(export_path)
else:
raise ValueError("Either export_path or export_type should be provided.")
if self.export_format not in self._SUPPORTED_FORMATS:
raise NotImplementedError(
f'export data format "{self.export_format}" is not supported '
f"for now. Only support {self._SUPPORTED_FORMATS}. Please check export_type or export_path."
)
self.export_extra_args = kwargs if kwargs is not None else {}

# Check if export_path is S3 and create filesystem if needed
self.s3_filesystem = None
if export_path.startswith("s3://"):
# Extract AWS credentials from export_extra_args (if provided)
s3_config = {}
if "aws_access_key_id" in self.export_extra_args:
s3_config["aws_access_key_id"] = self.export_extra_args.pop("aws_access_key_id")
if "aws_secret_access_key" in self.export_extra_args:
s3_config["aws_secret_access_key"] = self.export_extra_args.pop("aws_secret_access_key")
if "aws_session_token" in self.export_extra_args:
s3_config["aws_session_token"] = self.export_extra_args.pop("aws_session_token")
if "aws_region" in self.export_extra_args:
s3_config["aws_region"] = self.export_extra_args.pop("aws_region")
if "endpoint_url" in self.export_extra_args:
s3_config["endpoint_url"] = self.export_extra_args.pop("endpoint_url")

# Create PyArrow S3FileSystem with credentials
# This matches the pattern used in RayS3DataLoadStrategy
from data_juicer.utils.s3_utils import create_pyarrow_s3_filesystem

self.s3_filesystem = create_pyarrow_s3_filesystem(s3_config)
logger.info(f"Detected S3 export path: {export_path}. S3 filesystem configured.")
fs_args = copy.deepcopy(self.export_extra_args)
self.fs = create_filesystem_from_args(export_path, fs_args)
Collaborator review comment: Checking if the returned fs is None might be necessary.
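A minimal sketch of what that guard could look like (the helper name and the decision to raise are illustrative, not part of this PR):

    def _require_filesystem(fs, export_path):
        """Fail fast when a remote export path resolves to no filesystem."""
        # Hypothetical guard: local paths legitimately yield fs=None, so only
        # complain when the path clearly points at a remote store.
        if fs is None and export_path.startswith(("s3://", "hdfs://")):
            raise ValueError(f"Could not create a filesystem for remote export path: {export_path}")
        return fs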

self._check_shard_size()

def _check_shard_size(self):
if self.export_shard_size == 0:
return
self.max_shard_size_str = ""

# get the string format of shard size
@@ -149,22 +142,30 @@ def _export_impl(self, dataset, export_path, columns=None):
if len(removed_fields):
dataset = dataset.drop_columns(removed_fields)

export_method = RayExporter._router()[self.export_format]
router = self._router()
if self.export_format in router:
export_method = router[self.export_format]
else:
export_method = RayExporter.write_others

export_kwargs = {
"export_extra_args": self.export_extra_args,
"export_format": self.export_format,
}
# Add S3 filesystem if available
if self.s3_filesystem is not None:
export_kwargs["export_extra_args"]["filesystem"] = self.s3_filesystem
# Add filesystem if available
if self.fs is not None:
export_kwargs["export_extra_args"]["filesystem"] = self.fs

if self.export_shard_size > 0:
# compute the min_rows_per_file for export methods
dataset_nbytes = dataset.size_bytes()
dataset_num_rows = dataset.count()
num_shards = int(dataset_nbytes / self.export_shard_size) + 1
num_shards = min(num_shards, dataset_num_rows)
rows_per_file = int(dataset_num_rows / num_shards)
export_kwargs["export_extra_args"]["min_rows_per_file"] = rows_per_file

if dataset_num_rows > 0:
num_shards = int(dataset_nbytes / self.export_shard_size) + 1
num_shards = min(num_shards, dataset_num_rows)
rows_per_file = max(1, int(dataset_num_rows / num_shards))
export_kwargs["export_extra_args"]["min_rows_per_file"] = rows_per_file

return export_method(dataset, export_path, **export_kwargs)

def export(self, dataset, columns=None):
@@ -236,7 +237,61 @@ def write_others(dataset, export_path, **kwargs):
# Add S3 filesystem if available
if "filesystem" in export_extra_args:
filtered_kwargs["filesystem"] = export_extra_args["filesystem"]
return write_method(export_path, **filtered_kwargs)
if export_path:
return write_method(export_path, **filtered_kwargs)
else:
return write_method(**filtered_kwargs)

@staticmethod
def write_iceberg(dataset, export_path, **kwargs):
"""
Export method for iceberg target tables.
Checks for table existence/connectivity. If the check fails, safely falls back to a file-based export (JSON/JSONL by default).
"""
from pyiceberg.catalog import load_catalog
from pyiceberg.exceptions import NoSuchTableError

export_extra_args = kwargs.get("export_extra_args", {})
catalog_kwargs = export_extra_args.get("catalog_kwargs", {})
table_identifier = export_extra_args.get("table_identifier", export_path)

use_iceberg = False

try:
catalog = load_catalog(**catalog_kwargs)
catalog.load_table(table_identifier)
logger.info(f"Iceberg table {table_identifier} exists. Writing to Iceberg.")
use_iceberg = True

except NoSuchTableError as e:
logger.warning(
f"Iceberg target unavailable ({e.__class__.__name__}). Fallback to exporting to {export_path}..."
)
except Exception as e:
logger.error(f"Unexpected error checking Iceberg: {e}. Fallback to exporting to {export_path}...")

if use_iceberg:
try:
filtered_kwargs = filter_arguments(dataset.write_iceberg, export_extra_args)
return dataset.write_iceberg(table_identifier, **filtered_kwargs)
except Exception as e:
logger.error(f"Write to Iceberg failed during execution: {e}. Fallback to json...")

suffix = os.path.splitext(export_path)[-1].strip(".").lower()
if not suffix:
suffix = "jsonl"
logger.warning(f"No suffix found in {export_path}, using default fallback: {suffix}")

logger.info(f"Falling back to file export. Format: [{suffix}], Path: [{export_path}]")

fallback_kwargs = {}
if "filesystem" in export_extra_args:
fallback_kwargs["filesystem"] = export_extra_args["filesystem"]
if suffix in ["json", "jsonl"]:
return RayExporter.write_json(dataset, export_path, **fallback_kwargs)
else:
fallback_kwargs["export_format"] = suffix
return RayExporter.write_others(dataset, export_path, **fallback_kwargs)

# suffix to export method
@staticmethod
@@ -250,8 +305,5 @@ def _router():
"jsonl": RayExporter.write_json,
"json": RayExporter.write_json,
"webdataset": RayExporter.write_webdataset,
"parquet": RayExporter.write_others,
"csv": RayExporter.write_others,
"tfrecords": RayExporter.write_others,
"lance": RayExporter.write_others,
"iceberg": RayExporter.write_iceberg,
}
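For context, one hypothetical way this exporter could be configured for an Iceberg target, based on the keys write_iceberg reads from export_extra_args (the catalog properties, table identifier, and dataset variable below are placeholders, not taken from this PR):

    # Extra keyword arguments land in export_extra_args; write_iceberg looks up
    # "catalog_kwargs" and "table_identifier" there.
    exporter = RayExporter(
        export_path="./outputs/demo/demo-processed.jsonl",  # used only if the Iceberg check fails
        export_type="iceberg",
        catalog_kwargs={"name": "default", "type": "rest", "uri": "http://localhost:8181"},
        table_identifier="demo_db.demo_table",
    )
    exporter.export(processed_dataset)  # processed_dataset: a Ray Dataset produced upstream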
29 changes: 28 additions & 1 deletion data_juicer/utils/s3_utils.py
@@ -6,7 +6,7 @@
"""

import os
from typing import Dict, Tuple
from typing import Any, Dict, Tuple

import pyarrow.fs
from loguru import logger
@@ -117,3 +117,30 @@ def validate_s3_path(path: str) -> None:
"""
if not path.startswith("s3://"):
raise ValueError(f"S3 path must start with 's3://', got: {path}")


def create_filesystem_from_args(path: str, args: Dict[str, Any]):
"""
Create a PyArrow FileSystem based on the path prefix and parameters.
Automatically extract relevant credentials from args and pop them so they do not pollute subsequent parameters.
"""
fs = None
if path.startswith("s3://"):
validate_s3_path(path)

s3_keys = ["aws_access_key_id", "aws_secret_access_key", "aws_session_token", "aws_region", "endpoint_url"]
s3_conf = {k: args.pop(k) for k in s3_keys if k in args}
fs = create_pyarrow_s3_filesystem(s3_conf)
logger.info(f"Detected S3 export path: {path}. S3 filesystem configured.")

elif path.startswith("hdfs://"):
import pyarrow.fs as pa_fs

hdfs_keys = ["host", "port", "user", "kerb_ticket", "extra_conf"]
hdfs_conf = {k: args.pop(k) for k in hdfs_keys if k in args}
if "port" in hdfs_conf:
hdfs_conf["port"] = int(hdfs_conf["port"])
fs = pa_fs.HadoopFileSystem(**hdfs_conf)
logger.info(f"Detected HDFS export path: {path}. HDFS filesystem configured.")
Collaborator review comment on lines +128 to +144: Add an extra else branch to raise a warning or error for unsupported path prefixes.
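A sketch of the kind of guard the comment asks for (the constant and helper names are hypothetical, and warning vs. raising is the author's call):

    from urllib.parse import urlparse

    from loguru import logger

    _SUPPORTED_FS_SCHEMES = {"s3", "hdfs"}  # schemes handled explicitly above


    def warn_on_unsupported_scheme(path: str) -> None:
        """Warn when a remote-looking path has no dedicated filesystem handling."""
        scheme = urlparse(path).scheme
        if scheme and scheme not in _SUPPORTED_FS_SCHEMES:
            logger.warning(
                f"Path scheme '{scheme}://' is not handled by create_filesystem_from_args; "
                "returning None and relying on the writer's default filesystem resolution."
            )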


return fs
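A quick illustration of the intended call pattern with placeholder credentials: the S3-related keys are popped from the dict, and everything else is left for the downstream writer.

    args = {
        "aws_access_key_id": "AKIA...",           # placeholder
        "aws_secret_access_key": "***",           # placeholder
        "endpoint_url": "http://localhost:9000",  # e.g. a local MinIO endpoint
        "min_rows_per_file": 10000,               # unrelated writer kwarg, left untouched
    }
    fs = create_filesystem_from_args("s3://my-bucket/exports/", args)
    # After the call, args == {"min_rows_per_file": 10000} and fs is a PyArrow S3FileSystem.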
76 changes: 76 additions & 0 deletions demos/process_dist_sources/config/process_from_hdfs.yaml
@@ -0,0 +1,76 @@
# Process config example for dataset

# global parameters
project_name: 'demo'


np: 1                                                          # number of subprocesses to process your dataset

export_path: './outputs/demo/demo-processed-ray'

dataset:
configs:
- type: remote
source: hdfs
path: hdfs://your_hdfs_path/demo-dataset.jsonl
host: your_hdfs_host
port: 8020
user: your_username

# process schedule
# a list of several process operators with their arguments
process:
# Filter ops
- alphanumeric_filter: # filter text with alphabet/numeric ratio out of specific range.
tokenization: false # Whether to count the ratio of alphanumeric to the total number of tokens.
min_ratio: 0.0 # the min ratio of filter range
max_ratio: 0.9 # the max ratio of filter range
- average_line_length_filter: # filter text with the average length of lines out of specific range.
min_len: 10 # the min length of filter range
max_len: 10000 # the max length of filter range
- character_repetition_filter: # filter text with the character repetition ratio out of specific range
rep_len: 10 # repetition length for char-level n-gram
min_ratio: 0.0 # the min ratio of filter range
max_ratio: 0.5 # the max ratio of filter range
- flagged_words_filter: # filter text with the flagged-word ratio larger than a specific max value
lang: en # consider flagged words in what language
tokenization: false # whether to use model to tokenize documents
max_ratio: 0.0045 # the max ratio to filter text
flagged_words_dir: ./assets # directory to store flagged words dictionaries
use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese
words_aug_group_sizes: [2] # the group size of words to augment
words_aug_join_char: "" # the join char between words to augment
- language_id_score_filter: # filter text in specific language with language scores larger than a specific max value
lang: en # keep text in what language
min_score: 0.8 # the min language scores to filter text
- maximum_line_length_filter: # filter text with the maximum length of lines out of specific range
min_len: 10 # the min length of filter range
max_len: 10000 # the max length of filter range
- perplexity_filter: # filter text with perplexity score out of specific range
lang: en # compute perplexity in what language
max_ppl: 1500 # the max perplexity score to filter text
- special_characters_filter: # filter text with special-char ratio out of specific range
min_ratio: 0.0 # the min ratio of filter range
max_ratio: 0.25 # the max ratio of filter range
- stopwords_filter: # filter text with stopword ratio smaller than a specific min value
lang: en # consider stopwords in what language
tokenization: false # whether to use model to tokenize documents
min_ratio: 0.3 # the min ratio to filter text
stopwords_dir: ./assets # directory to store stopwords dictionaries
use_words_aug: false # whether to augment words, especially for Chinese and Vietnamese
words_aug_group_sizes: [2] # the group size of words to augment
words_aug_join_char: "" # the join char between words to augment
- text_length_filter: # filter text with length out of specific range
min_len: 10 # the min length of filter range
max_len: 10000 # the max length of filter range
- words_num_filter: # filter text with number of words out of specific range
lang: en # sample in which language
tokenization: false # whether to use model to tokenize documents
min_num: 10 # the min number of filter range
max_num: 10000 # the max number of filter range
- word_repetition_filter: # filter text with the word repetition ratio out of specific range
lang: en # sample in which language
tokenization: false # whether to use model to tokenize documents
rep_len: 10 # repetition length for word-level n-gram
min_ratio: 0.0 # the min ratio of filter range
max_ratio: 0.5 # the max ratio of filter range
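For reference, the host/port/user fields above presumably map onto a PyArrow HadoopFileSystem roughly as sketched below (placeholders kept from the config; a working libhdfs/Hadoop client environment on the Ray nodes is assumed):

    import pyarrow.fs as pa_fs

    fs = pa_fs.HadoopFileSystem(host="your_hdfs_host", port=8020, user="your_username")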