Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 41 additions & 30 deletions data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,15 +224,22 @@ def dup_idx(self, queries):
return [idx for uid, idx in queries if uid in self.parent]


def get_remote_classes():
"""Get remote versions of classes with Ray decorators applied at runtime."""
def get_remote_classes(actor_memory: Optional[int] = None):
"""Get remote versions of classes with Ray decorators applied at runtime.

:param actor_memory: Memory reservation for EdgeBuffer and BTSUnionFind actors in bytes.
"""
# Apply ray.method decorator to get_next_id at runtime
IdGenerator.get_next_id = ray.method(num_returns=2)(IdGenerator.get_next_id)

remote_args = {"scheduling_strategy": "SPREAD"}
if actor_memory is not None:
remote_args["memory"] = actor_memory

return {
"IdGenerator": ray.remote(IdGenerator),
"EdgeBuffer": ray.remote(scheduling_strategy="SPREAD")(EdgeBuffer),
"BTSUnionFind": ray.remote(scheduling_strategy="SPREAD")(BTSUnionFind),
"EdgeBuffer": ray.remote(**remote_args)(EdgeBuffer),
"BTSUnionFind": ray.remote(**remote_args)(BTSUnionFind),
}


Expand Down Expand Up @@ -322,6 +329,8 @@ def __init__(
merge_batch_size: Optional[int] = 1000,
minhash_batch_size: Optional[Union[int, str]] = "auto",
memory_per_sample: Optional[float] = 0.1, # MB per sample
actor_memory: Optional[int] = None, # Memory per actor (bytes)
task_memory: Optional[int] = None, # Memory per map_batches task (bytes)
*args,
**kwargs,
):
Expand Down Expand Up @@ -376,6 +385,12 @@ def __init__(
:param memory_per_sample: estimated memory needed per sample in MB.
Used to calculate batch size based on available GPU memory.
Default is 0.1 MB per sample.
:param actor_memory: Memory reservation per BTSUnionFind/EdgeBuffer
actor in bytes. For billion-row scale, use 20_000_000_000 (20GB).
Default is None (no reservation).
:param task_memory: Memory reservation per map_batches task in bytes.
For billion-row scale, use 2_000_000_000 (2GB).
Default is None (no reservation).
"""

super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -490,6 +505,8 @@ def tokenization_func(text):
self.max_pending_filter_tasks = max_pending_filter_tasks
self.num_filter_task_returns = num_filter_task_returns
self.union_threshold = union_threshold
self.actor_memory = actor_memory
self.task_memory = task_memory

# Lazy initialization - actors created in _ensure_actors()
self._actors_initialized = False
Expand All @@ -514,9 +531,11 @@ def _ensure_actors(self):
self.merge_batch_size = min(self._merge_batch_size_config, self.union_find_parallel_num)

logger.info(f"union_find_parallel_num = {self.union_find_parallel_num}")
if self.actor_memory is not None:
logger.info(f"actor_memory = {self.actor_memory}")

# Create actors NOW when cluster has resources
remote_classes = get_remote_classes()
remote_classes = get_remote_classes(actor_memory=self.actor_memory)
self.remote_edge_buffers = [remote_classes["EdgeBuffer"].remote() for _ in range(self.union_find_parallel_num)]
self.union_find_list = [
remote_classes["BTSUnionFind"].remote(
Expand All @@ -530,12 +549,24 @@ def _ensure_actors(self):
for i in range(self.union_find_parallel_num)
]

# Wait for all actors to be ready before proceeding
ray.get(
[uf.__ray_ready__.remote() for uf in self.union_find_list]
+ [eb.__ray_ready__.remote() for eb in self.remote_edge_buffers]
)

empty_hash_value = np.full((self.num_rows_per_band,), MAX_HASH, dtype=np.uint32)
self.empty_hash_value = b"\x00\x00\x00\x00" + empty_hash_value.tobytes()
self.empty_hash_table_id = int(MAX_HASH % self.union_find_parallel_num)

self._actors_initialized = True

def _get_map_batches_kwargs(self):
kwargs = {"batch_format": "pyarrow", "zero_copy_batch": True}
if self.task_memory is not None:
kwargs["memory"] = self.task_memory
return kwargs

def band_minhash(self, minhash_list, uid_list):
"""
Logic for creating and pusing LSH bands to the union find list
Expand Down Expand Up @@ -701,19 +732,11 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
concurrency=ray.data.ActorPoolStrategy(size=concurrency),
batch_size=batch_size,
)
dataset.map_batches(
band_with_uid,
batch_format="pyarrow",
zero_copy_batch=True,
).write_parquet(tmp_dir)
dataset.map_batches(band_with_uid, **self._get_map_batches_kwargs()).write_parquet(tmp_dir)
del dataset
else:
logger.info("Using CPU for MinHash computation")
dataset.map_batches(
minhash_with_uid,
batch_format="pyarrow",
zero_copy_batch=True,
).write_parquet(tmp_dir)
dataset.map_batches(minhash_with_uid, **self._get_map_batches_kwargs()).write_parquet(tmp_dir)
end_time = time.time()
logger.info(f"MinHash time = {end_time - start_time}")
new_dataset = ray.data.read_parquet(tmp_dir)
Expand All @@ -722,11 +745,7 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
end_time = time.time()
logger.info(f"merge time = {end_time - start_time}")
start_time = time.time()
result = new_dataset.map_batches(
self.filter_with_union_find,
batch_format="pyarrow",
zero_copy_batch=True,
)
result = new_dataset.map_batches(self.filter_with_union_find, **self._get_map_batches_kwargs())
end_time = time.time()
logger.info(f"filter time = {end_time - start_time}")
return result
Expand Down Expand Up @@ -766,21 +785,13 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
self.calc_minhash(table[self.text_key], uid_list)
return table

dataset.map_batches(
minhash_with_uid,
batch_format="pyarrow",
zero_copy_batch=True,
).materialize()
dataset.map_batches(minhash_with_uid, **self._get_map_batches_kwargs()).materialize()
end_time = time.time()
logger.info(f"MinHash time = {end_time - start_time}")

start_time = time.time()
self.merge()
end_time = time.time()
logger.info(f"merge time = {end_time - start_time}")
result = dataset.map_batches(
self.filter_with_union_find,
batch_format="pyarrow",
zero_copy_batch=True,
)
result = dataset.map_batches(self.filter_with_union_find, **self._get_map_batches_kwargs())
return result
Loading