77 changes: 47 additions & 30 deletions data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
@@ -224,15 +224,22 @@ def dup_idx(self, queries):
return [idx for uid, idx in queries if uid in self.parent]


def get_remote_classes():
"""Get remote versions of classes with Ray decorators applied at runtime."""
def get_remote_classes(actor_memory: Optional[int] = None):
"""Get remote versions of classes with Ray decorators applied at runtime.

:param actor_memory: Memory reservation for EdgeBuffer and BTSUnionFind actors in bytes.
"""
# Apply ray.method decorator to get_next_id at runtime
IdGenerator.get_next_id = ray.method(num_returns=2)(IdGenerator.get_next_id)

remote_args = {"scheduling_strategy": "SPREAD"}
if actor_memory is not None:
remote_args["memory"] = actor_memory

return {
"IdGenerator": ray.remote(IdGenerator),
"EdgeBuffer": ray.remote(scheduling_strategy="SPREAD")(EdgeBuffer),
"BTSUnionFind": ray.remote(scheduling_strategy="SPREAD")(BTSUnionFind),
"EdgeBuffer": ray.remote(**remote_args)(EdgeBuffer),
"BTSUnionFind": ray.remote(**remote_args)(BTSUnionFind),
}
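
For context, Ray's `memory` remote arg is a scheduling reservation, not a hard limit: the actor is only placed on a node with that much logical memory still unreserved. A minimal sketch of the pattern the helper now builds; the actor class and the 20 GB figure are illustrative, not from this file:

```python
import ray

# Illustrative only: a SPREAD-scheduled actor with a 20 GB memory reservation,
# mirroring what get_remote_classes(actor_memory=20_000_000_000) now produces
# for EdgeBuffer and BTSUnionFind. The reservation influences placement; it is
# not a hard memory cap on the actor process.
remote_args = {"scheduling_strategy": "SPREAD", "memory": 20_000_000_000}

@ray.remote(**remote_args)
class DemoBuffer:
    def __init__(self):
        self.edges = []

    def add_edge(self, edge):
        self.edges.append(edge)
        return len(self.edges)

buf = DemoBuffer.remote()
print(ray.get(buf.add_edge.remote((1, 2))))  # -> 1
```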


@@ -322,6 +329,8 @@ def __init__(
merge_batch_size: Optional[int] = 1000,
minhash_batch_size: Optional[Union[int, str]] = "auto",
memory_per_sample: Optional[float] = 0.1, # MB per sample
actor_memory: Optional[int] = None, # Memory per actor (bytes)
task_memory: Optional[int] = None, # Memory per map_batches task (bytes)
*args,
**kwargs,
):
@@ -376,6 +385,12 @@ def __init__(
:param memory_per_sample: estimated memory needed per sample in MB.
Used to calculate batch size based on available GPU memory.
Default is 0.1 MB per sample.
:param actor_memory: Memory reservation per BTSUnionFind/EdgeBuffer
actor in bytes. For billion-row scale, use 20_000_000_000 (20GB).
Default is None (no reservation).
:param task_memory: Memory reservation per map_batches task in bytes.
For billion-row scale, use 2_000_000_000 (2GB).
Default is None (no reservation).
"""

super().__init__(*args, **kwargs)
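
A hedged example of how the two new knobs would be passed for a billion-row run, following the guidance in the docstring above; the operator class name is assumed from the module name, and only `actor_memory` and `task_memory` come from this change:

```python
# Hypothetical instantiation; values follow the docstring's billion-row guidance.
dedup = RayBTSMinhashDeduplicator(
    actor_memory=20_000_000_000,  # 20 GB reserved per BTSUnionFind/EdgeBuffer actor
    task_memory=2_000_000_000,    # 2 GB reserved per map_batches task
)
```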
@@ -490,6 +505,8 @@ def tokenization_func(text):
self.max_pending_filter_tasks = max_pending_filter_tasks
self.num_filter_task_returns = num_filter_task_returns
self.union_threshold = union_threshold
self.actor_memory = actor_memory
self.task_memory = task_memory

# Lazy initialization - actors created in _ensure_actors()
self._actors_initialized = False
@@ -514,9 +531,11 @@ def _ensure_actors(self):
self.merge_batch_size = min(self._merge_batch_size_config, self.union_find_parallel_num)

logger.info(f"union_find_parallel_num = {self.union_find_parallel_num}")
if self.actor_memory is not None:
logger.info(f"actor_memory = {self.actor_memory}")

# Create actors NOW when cluster has resources
remote_classes = get_remote_classes()
remote_classes = get_remote_classes(actor_memory=self.actor_memory)
self.remote_edge_buffers = [remote_classes["EdgeBuffer"].remote() for _ in range(self.union_find_parallel_num)]
self.union_find_list = [
remote_classes["BTSUnionFind"].remote(
@@ -530,6 +549,9 @@
for i in range(self.union_find_parallel_num)
]

# Wait for all actors to be ready before proceeding
ray.get([uf.__ray_ready__.remote() for uf in self.union_find_list])

empty_hash_value = np.full((self.num_rows_per_band,), MAX_HASH, dtype=np.uint32)
self.empty_hash_value = b"\x00\x00\x00\x00" + empty_hash_value.tobytes()
self.empty_hash_table_id = int(MAX_HASH % self.union_find_parallel_num)
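
The `__ray_ready__` call above blocks until every union-find actor has actually been placed and constructed, so downstream work does not start racing against actor creation. A standalone sketch of the same pattern, with illustrative names and sizes:

```python
import ray

@ray.remote(memory=2_000_000_000)  # reservation must fit on some node for placement
class Worker:
    def work(self):
        return "done"

workers = [Worker.remote() for _ in range(4)]
# __ray_ready__ is a built-in no-op actor method; ray.get on it returns only
# once each actor has been scheduled and its __init__ has finished.
ray.get([w.__ray_ready__.remote() for w in workers])
```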
@@ -701,19 +723,17 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
concurrency=ray.data.ActorPoolStrategy(size=concurrency),
batch_size=batch_size,
)
dataset.map_batches(
band_with_uid,
batch_format="pyarrow",
zero_copy_batch=True,
).write_parquet(tmp_dir)
band_kwargs = {"batch_format": "pyarrow", "zero_copy_batch": True}
if self.task_memory is not None:
band_kwargs["memory"] = self.task_memory
dataset.map_batches(band_with_uid, **band_kwargs).write_parquet(tmp_dir)
del dataset
else:
logger.info("Using CPU for MinHash computation")
dataset.map_batches(
minhash_with_uid,
batch_format="pyarrow",
zero_copy_batch=True,
).write_parquet(tmp_dir)
map_batches_kwargs = {"batch_format": "pyarrow", "zero_copy_batch": True}
if self.task_memory is not None:
map_batches_kwargs["memory"] = self.task_memory
dataset.map_batches(minhash_with_uid, **map_batches_kwargs).write_parquet(tmp_dir)
end_time = time.time()
logger.info(f"MinHash time = {end_time - start_time}")
new_dataset = ray.data.read_parquet(tmp_dir)
@@ -722,11 +742,10 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
end_time = time.time()
logger.info(f"merge time = {end_time - start_time}")
start_time = time.time()
result = new_dataset.map_batches(
self.filter_with_union_find,
batch_format="pyarrow",
zero_copy_batch=True,
)
filter_kwargs = {"batch_format": "pyarrow", "zero_copy_batch": True}
if self.task_memory is not None:
filter_kwargs["memory"] = self.task_memory
result = new_dataset.map_batches(self.filter_with_union_find, **filter_kwargs)
end_time = time.time()
logger.info(f"filter time = {end_time - start_time}")
return result
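
The `memory` key added to the `map_batches` kwargs above is forwarded as a Ray remote arg, i.e. a per-task scheduling reservation. A minimal, self-contained sketch of the same call shape; the dataset and transform below are placeholders, not the operator's real functions:

```python
import ray
import pyarrow as pa

def passthrough(table: pa.Table) -> pa.Table:
    # Stand-in for minhash_with_uid / filter_with_union_find.
    return table

ds = ray.data.range(1_000)
ds = ds.map_batches(
    passthrough,
    batch_format="pyarrow",
    zero_copy_batch=True,
    memory=2_000_000_000,  # bytes reserved per task; omitted when task_memory is None
)
```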
@@ -766,21 +785,19 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
self.calc_minhash(table[self.text_key], uid_list)
return table

dataset.map_batches(
minhash_with_uid,
batch_format="pyarrow",
zero_copy_batch=True,
).materialize()
map_batches_kwargs = {"batch_format": "pyarrow", "zero_copy_batch": True}
if self.task_memory is not None:
map_batches_kwargs["memory"] = self.task_memory
dataset.map_batches(minhash_with_uid, **map_batches_kwargs).materialize()
end_time = time.time()
logger.info(f"MinHash time = {end_time - start_time}")

start_time = time.time()
self.merge()
end_time = time.time()
logger.info(f"merge time = {end_time - start_time}")
result = dataset.map_batches(
self.filter_with_union_find,
batch_format="pyarrow",
zero_copy_batch=True,
)
filter_kwargs = {"batch_format": "pyarrow", "zero_copy_batch": True}
if self.task_memory is not None:
filter_kwargs["memory"] = self.task_memory
result = dataset.map_batches(self.filter_with_union_find, **filter_kwargs)
return result
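
If the reservations need verifying on a live cluster, Ray exposes logical memory as a schedulable resource; this is an optional sanity check, not part of the change:

```python
import ray

ray.init(address="auto")  # attach to the already-running cluster
total = ray.cluster_resources().get("memory", 0)
free = ray.available_resources().get("memory", 0)
# Both values are in bytes; the gap reflects memory currently reserved by actors
# and tasks, including the actor_memory/task_memory reservations made here.
print(f"reserved: {(total - free) / 1e9:.1f} GB of {total / 1e9:.1f} GB")
```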