diff --git a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
index 5583fc9627..1bb959b8d3 100644
--- a/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
+++ b/data_juicer/ops/deduplicator/ray_bts_minhash_deduplicator.py
@@ -224,15 +224,22 @@ def dup_idx(self, queries):
         return [idx for uid, idx in queries if uid in self.parent]
 
 
-def get_remote_classes():
-    """Get remote versions of classes with Ray decorators applied at runtime."""
+def get_remote_classes(actor_memory: Optional[int] = None):
+    """Get remote versions of classes with Ray decorators applied at runtime.
+
+    :param actor_memory: Memory reservation for EdgeBuffer and BTSUnionFind actors in bytes.
+    """
     # Apply ray.method decorator to get_next_id at runtime
     IdGenerator.get_next_id = ray.method(num_returns=2)(IdGenerator.get_next_id)
 
+    remote_args = {"scheduling_strategy": "SPREAD"}
+    if actor_memory is not None:
+        remote_args["memory"] = actor_memory
+
     return {
         "IdGenerator": ray.remote(IdGenerator),
-        "EdgeBuffer": ray.remote(scheduling_strategy="SPREAD")(EdgeBuffer),
-        "BTSUnionFind": ray.remote(scheduling_strategy="SPREAD")(BTSUnionFind),
+        "EdgeBuffer": ray.remote(**remote_args)(EdgeBuffer),
+        "BTSUnionFind": ray.remote(**remote_args)(BTSUnionFind),
     }
 
 
@@ -322,6 +329,8 @@ def __init__(
         merge_batch_size: Optional[int] = 1000,
         minhash_batch_size: Optional[Union[int, str]] = "auto",
         memory_per_sample: Optional[float] = 0.1,  # MB per sample
+        actor_memory: Optional[int] = None,  # Memory per actor (bytes)
+        task_memory: Optional[int] = None,  # Memory per map_batches task (bytes)
         *args,
         **kwargs,
     ):
@@ -376,6 +385,12 @@ def __init__(
         :param memory_per_sample: estimated memory needed per sample in MB.
             Used to calculate batch size based on available GPU memory.
             Default is 0.1 MB per sample.
+        :param actor_memory: Memory reservation per BTSUnionFind/EdgeBuffer
+            actor in bytes. For billion-row scale, use 20_000_000_000 (20GB).
+            Default is None (no reservation).
+        :param task_memory: Memory reservation per map_batches task in bytes.
+            For billion-row scale, use 2_000_000_000 (2GB).
+            Default is None (no reservation).
         """
         super().__init__(*args, **kwargs)
 
@@ -490,6 +505,8 @@ def tokenization_func(text):
         self.max_pending_filter_tasks = max_pending_filter_tasks
         self.num_filter_task_returns = num_filter_task_returns
         self.union_threshold = union_threshold
+        self.actor_memory = actor_memory
+        self.task_memory = task_memory
 
         # Lazy initialization - actors created in _ensure_actors()
         self._actors_initialized = False
@@ -514,9 +531,11 @@ def _ensure_actors(self):
             self.merge_batch_size = min(self._merge_batch_size_config, self.union_find_parallel_num)
 
         logger.info(f"union_find_parallel_num = {self.union_find_parallel_num}")
+        if self.actor_memory is not None:
+            logger.info(f"actor_memory = {self.actor_memory}")
 
         # Create actors NOW when cluster has resources
-        remote_classes = get_remote_classes()
+        remote_classes = get_remote_classes(actor_memory=self.actor_memory)
         self.remote_edge_buffers = [remote_classes["EdgeBuffer"].remote() for _ in range(self.union_find_parallel_num)]
         self.union_find_list = [
             remote_classes["BTSUnionFind"].remote(
@@ -530,12 +549,24 @@
             for i in range(self.union_find_parallel_num)
         ]
 
+        # Wait for all actors to be ready before proceeding
+        ray.get(
+            [uf.__ray_ready__.remote() for uf in self.union_find_list]
+            + [eb.__ray_ready__.remote() for eb in self.remote_edge_buffers]
+        )
+
         empty_hash_value = np.full((self.num_rows_per_band,), MAX_HASH, dtype=np.uint32)
         self.empty_hash_value = b"\x00\x00\x00\x00" + empty_hash_value.tobytes()
         self.empty_hash_table_id = int(MAX_HASH % self.union_find_parallel_num)
 
         self._actors_initialized = True
 
+    def _get_map_batches_kwargs(self):
+        kwargs = {"batch_format": "pyarrow", "zero_copy_batch": True}
+        if self.task_memory is not None:
+            kwargs["memory"] = self.task_memory
+        return kwargs
+
     def band_minhash(self, minhash_list, uid_list):
         """
         Logic for creating and pusing LSH bands to the union find list
@@ -701,19 +732,11 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
                 concurrency=ray.data.ActorPoolStrategy(size=concurrency),
                 batch_size=batch_size,
             )
-            dataset.map_batches(
-                band_with_uid,
-                batch_format="pyarrow",
-                zero_copy_batch=True,
-            ).write_parquet(tmp_dir)
+            dataset.map_batches(band_with_uid, **self._get_map_batches_kwargs()).write_parquet(tmp_dir)
             del dataset
         else:
             logger.info("Using CPU for MinHash computation")
-            dataset.map_batches(
-                minhash_with_uid,
-                batch_format="pyarrow",
-                zero_copy_batch=True,
-            ).write_parquet(tmp_dir)
+            dataset.map_batches(minhash_with_uid, **self._get_map_batches_kwargs()).write_parquet(tmp_dir)
         end_time = time.time()
         logger.info(f"MinHash time = {end_time - start_time}")
         new_dataset = ray.data.read_parquet(tmp_dir)
@@ -722,11 +745,7 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
         end_time = time.time()
         logger.info(f"merge time = {end_time - start_time}")
         start_time = time.time()
-        result = new_dataset.map_batches(
-            self.filter_with_union_find,
-            batch_format="pyarrow",
-            zero_copy_batch=True,
-        )
+        result = new_dataset.map_batches(self.filter_with_union_find, **self._get_map_batches_kwargs())
        end_time = time.time()
         logger.info(f"filter time = {end_time - start_time}")
         return result
@@ -766,11 +785,7 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
                 self.calc_minhash(table[self.text_key], uid_list)
                 return table
 
-            dataset.map_batches(
-                minhash_with_uid,
-                batch_format="pyarrow",
-                zero_copy_batch=True,
-            ).materialize()
+            dataset.map_batches(minhash_with_uid, **self._get_map_batches_kwargs()).materialize()
             end_time = time.time()
             logger.info(f"MinHash time = {end_time - start_time}")
 
@@ -778,9 +793,5 @@ def minhash_with_uid(table: pa.Table) -> pa.Table:
             self.merge()
             end_time = time.time()
             logger.info(f"merge time = {end_time - start_time}")
-            result = dataset.map_batches(
-                self.filter_with_union_find,
-                batch_format="pyarrow",
-                zero_copy_batch=True,
-            )
+            result = dataset.map_batches(self.filter_with_union_find, **self._get_map_batches_kwargs())
             return result