diff --git a/data_juicer/utils/model_utils.py b/data_juicer/utils/model_utils.py
index 95c86a5b56..546f920591 100644
--- a/data_juicer/utils/model_utils.py
+++ b/data_juicer/utils/model_utils.py
@@ -1564,6 +1564,14 @@ def get_model(model_key=None, rank=None, use_cuda=False):
     global MODEL_ZOO
     if model_key not in MODEL_ZOO:
         logger.debug(f"{model_key} not found in MODEL_ZOO ({mp.current_process().name})")
+
+        # Configure thread limits in worker processes to prevent thread over-subscription
+        # when running with multiple processes (num_proc > 1)
+        if mp.current_process().name != "MainProcess":
+            from data_juicer.utils.process_utils import setup_worker_threads
+
+            setup_worker_threads(num_threads=1)
+
         if use_cuda and cuda_device_count() > 0:
             rank = rank if rank is not None else 0
             rank = rank % cuda_device_count()
diff --git a/data_juicer/utils/process_utils.py b/data_juicer/utils/process_utils.py
index 45bbc9c38a..e57e3ffcff 100644
--- a/data_juicer/utils/process_utils.py
+++ b/data_juicer/utils/process_utils.py
@@ -18,6 +18,41 @@
 # This leaves some memory for Ray's overhead and other system processes.
 _OPS_MEMORY_LIMIT_FRACTION = 1.0
 
+# Track whether worker threads have been configured
+_WORKER_THREADS_CONFIGURED = False
+
+
+def setup_worker_threads(num_threads=1):
+    """
+    Configure thread limits for worker processes to prevent thread over-subscription.
+
+    When running with multiple worker processes (e.g., num_proc > 1), each worker
+    using multiple threads leads to severe performance degradation due to thread
+    contention. This function limits threads per worker to prevent this issue.
+
+    :param num_threads: Number of threads per worker process (default: 1)
+    """
+    global _WORKER_THREADS_CONFIGURED
+
+    # Only configure once per process
+    if _WORKER_THREADS_CONFIGURED:
+        return
+
+    # Set PyTorch thread limits directly (works even after torch is imported)
+    try:
+        import torch
+
+        torch.set_num_threads(num_threads)
+        torch.set_num_interop_threads(num_threads)
+        logger.debug(f"Set torch threads to {num_threads}")
+    except ImportError:
+        pass
+    except RuntimeError as e:
+        # torch.set_num_interop_threads can only be called once
+        logger.debug(f"Could not set torch interop threads: {e}")
+
+    _WORKER_THREADS_CONFIGURED = True
+
 
 def setup_mp(method=None):
     if mp.current_process().name != "MainProcess":
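
For reviewers, a minimal standalone sketch of the over-subscription pattern this patch addresses: without a cap, each of N worker processes spins up torch's default intra-op thread pool (roughly one thread per core), giving N × cores threads in total. The pool size, `spawn` context, and workload below are illustrative assumptions, not part of the change:

```python
# Sketch: cap per-worker torch threads so 4 workers use ~4 threads total
# rather than 4 x torch's per-core default. Assumes torch is installed;
# the pool size and workload are illustrative only.
import multiprocessing as mp


def _worker(n):
    import torch

    # Mirrors the patch: only non-main (worker) processes are throttled.
    if mp.current_process().name != "MainProcess":
        torch.set_num_threads(1)
    return torch.arange(n).sum().item()


if __name__ == "__main__":
    with mp.get_context("spawn").Pool(processes=4) as pool:
        print(pool.map(_worker, [10, 100, 1000, 10000]))
```

Note that `torch.set_num_interop_threads` (also called in the patch) can only be invoked once per process and raises a `RuntimeError` if inter-op parallel work has already started, which is why `setup_worker_threads` wraps it in a try/except rather than letting the failure propagate.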