8 changes: 8 additions & 0 deletions data_juicer/utils/model_utils.py
```diff
@@ -1564,6 +1564,14 @@ def get_model(model_key=None, rank=None, use_cuda=False):
     global MODEL_ZOO
     if model_key not in MODEL_ZOO:
         logger.debug(f"{model_key} not found in MODEL_ZOO ({mp.current_process().name})")
+
+        # Configure thread limits in worker processes to prevent thread over-subscription
+        # when running with multiple processes (num_proc > 1)
+        if mp.current_process().name != "MainProcess":
+            from data_juicer.utils.process_utils import setup_worker_threads
+
+            setup_worker_threads(num_threads=1)
+
         if use_cuda and cuda_device_count() > 0:
             rank = rank if rank is not None else 0
             rank = rank % cuda_device_count()
```
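For context on what this hunk guards against: with `num_proc` worker processes and PyTorch's default of roughly one intra-op thread per CPU core, the machine ends up running about `num_proc × cores` compute threads, which contend with each other instead of doing work. Below is a minimal sketch of the pattern; the pool size, matrix size, and `worker` function are illustrative, not code from this PR.

```python
# Illustrative sketch of thread over-subscription (not code from this PR).
import multiprocessing as mp


def worker(limit_threads: bool) -> int:
    import torch  # imported in the child so each process configures its own copy

    if limit_threads:
        # Mirrors the effect of setup_worker_threads(num_threads=1) in each worker.
        torch.set_num_threads(1)
    x = torch.randn(1024, 1024)
    (x @ x).sum().item()  # CPU matmul: unlimited, each worker spawns ~one thread per core
    return torch.get_num_threads()


if __name__ == "__main__":
    with mp.get_context("spawn").Pool(8) as pool:
        # With [False] * 8 instead, all 8 workers each use a full core's worth of
        # threads and contend; with [True] * 8 each worker stays single-threaded.
        print(pool.map(worker, [True] * 8))
```

Hooking the call into `get_model` means the limit is applied lazily, only in worker processes that actually load a model, while the main process keeps its default thread settings.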
35 changes: 35 additions & 0 deletions data_juicer/utils/process_utils.py
```diff
@@ -18,6 +18,41 @@
 # This leaves some memory for Ray's overhead and other system processes.
 _OPS_MEMORY_LIMIT_FRACTION = 1.0
+
+# Track whether worker threads have been configured
+_WORKER_THREADS_CONFIGURED = False
+
+
+def setup_worker_threads(num_threads=1):
+    """
+    Configure thread limits for worker processes to prevent thread over-subscription.
+
+    When running with multiple worker processes (e.g., num_proc > 1), each worker
+    using multiple threads leads to severe performance degradation due to thread
+    contention. This function limits threads per worker to prevent this issue.
+
+    :param num_threads: Number of threads per worker process (default: 1)
+    """
+    global _WORKER_THREADS_CONFIGURED
+
+    # Only configure once per process
+    if _WORKER_THREADS_CONFIGURED:
+        return
+
+    # Set PyTorch thread limits directly (works even after torch is imported)
+    try:
+        import torch
+
+        torch.set_num_threads(num_threads)
+        torch.set_num_interop_threads(num_threads)
+        logger.debug(f"Set torch threads to {num_threads}")
+    except ImportError:
+        pass
+    except RuntimeError as e:
+        # torch.set_num_interop_threads can only be called once
+        logger.debug(f"Could not set torch interop threads: {e}")
+
+    _WORKER_THREADS_CONFIGURED = True
+
+
 def setup_mp(method=None):
     if mp.current_process().name != "MainProcess":
```
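A minimal usage sketch of the new helper (the standalone pool harness is illustrative, not part of the PR): repeated calls within one process are no-ops thanks to the module-level `_WORKER_THREADS_CONFIGURED` flag, and the `RuntimeError` branch matters because `torch.set_num_interop_threads` may only be called once, before any inter-op parallel work has started.

```python
# Illustrative harness for the new helper (not code from this PR).
import multiprocessing as mp

from data_juicer.utils.process_utils import setup_worker_threads


def load_and_check(_: int) -> int:
    import torch

    setup_worker_threads(num_threads=1)  # first call: sets torch thread limits
    setup_worker_threads(num_threads=1)  # second call: returns early, already configured
    return torch.get_num_threads()  # expected: 1 in every worker


if __name__ == "__main__":
    with mp.get_context("spawn").Pool(4) as pool:
        print(pool.map(load_and_check, range(4)))  # e.g. [1, 1, 1, 1]
```

Catching `ImportError` separately keeps the helper safe in environments without PyTorch, since the thread limits are only meaningful once torch is actually in use.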