
Commit 081e80d

pytorchbot committed: 2025-11-03 nightly release (2f4b794)

1 parent 6f1e503, commit 081e80d

File tree: 2 files changed, +25 -16 lines


torchrec/distributed/benchmark/base.py

Lines changed: 24 additions & 15 deletions
@@ -425,17 +425,11 @@ def _load_config_file(
        if not config_path:
            return {}

-        try:
-            with open(config_path, "r") as f:
-                if is_json:
-                    return json.load(f) or {}
-                else:
-                    return yaml.safe_load(f) or {}
-        except Exception as e:
-            logger.error(
-                f"Failed to load config because {e}. Proceeding without it."
-            )
-            return {}
+        with open(config_path, "r") as f:
+            if is_json:
+                return json.load(f) or {}
+            else:
+                return yaml.safe_load(f) or {}

    @functools.wraps(func)
    def wrapper() -> Any:  # pyre-ignore [3]
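With the try/except removed, a missing or malformed config file now raises instead of being silently swallowed and replaced by an empty dict. A standalone sketch of the same loading pattern, for context (the helper name load_config is illustrative, not the torchrec API; it assumes PyYAML is installed, which base.py already imports as yaml):

import json
from typing import Any, Dict, Optional

import yaml


def load_config(config_path: Optional[str], is_json: bool = False) -> Dict[str, Any]:
    # No config path supplied: fall back to an empty mapping.
    if not config_path:
        return {}

    # Any I/O or parse error now propagates to the caller, so a bad
    # --json_config / --yaml_config value fails loudly instead of being
    # logged and ignored.
    with open(config_path, "r") as f:
        if is_json:
            return json.load(f) or {}
        return yaml.safe_load(f) or {}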
@@ -479,7 +473,12 @@ def wrapper() -> Any:  # pyre-ignore [3]
        # Merge the two dictionaries, JSON overrides YAML
        merged_defaults = {**yaml_defaults, **json_defaults}

-        seen_args = set()  # track all --<name> we've added
+        # track all --<name> we've added
+        seen_args = {
+            "json_config",
+            "yaml_config",
+            "loglevel",
+        }

        for _name, param in sig.parameters.items():
            cls = param.annotation
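Pre-seeding seen_args appears to protect the flags the decorator already registers itself (--json_config, --yaml_config, --loglevel): argparse raises ArgumentError if the same option string is added twice, so config fields with those names are skipped rather than re-added. A toy sketch of that dedup pattern (the field names are hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--loglevel", type=str, default="WARNING")

# Option names already owned by the CLI wrapper itself.
seen_args = {"json_config", "yaml_config", "loglevel"}

# Field names discovered from a config dataclass (hypothetical).
for field_name in ("batch_size", "loglevel"):
    if field_name in seen_args:
        continue  # --loglevel is already registered above; adding it again would raise
    parser.add_argument(f"--{field_name}", type=int, default=0)
    seen_args.add(field_name)

args = parser.parse_args(["--batch_size", "32", "--loglevel", "INFO"])
print(args.batch_size, args.loglevel)  # 32 INFO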
@@ -548,7 +547,12 @@ def wrapper() -> Any:  # pyre-ignore [3]
        logger.info(config_instance)

        loglevel = logging._nameToLevel[args.loglevel.upper()]
-        logger.setLevel(loglevel)
+        # Set loglevel for all existing loggers
+        for existing_logger_name in logging.root.manager.loggerDict:
+            existing_logger = logging.getLogger(existing_logger_name)
+            existing_logger.setLevel(loglevel)
+        # Also set the root logger
+        logging.root.setLevel(loglevel)

        return func(**kwargs)
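Calling setLevel only on the module's own logger left loggers created elsewhere (for example in other torchrec modules) at their defaults. Walking logging.root.manager.loggerDict applies the requested level to every logger registered so far, and setting the root logger covers loggers created afterwards, since they fall back to the root's effective level. The same idea in isolation (the function name is illustrative):

import logging

logging.basicConfig()
module_logger = logging.getLogger("torchrec.some.module")  # stand-in for a library logger


def set_global_log_level(level_name: str) -> None:
    level = logging._nameToLevel[level_name.upper()]  # same mapping the decorator uses
    # Update every logger that has been registered so far...
    for name in logging.root.manager.loggerDict:
        logging.getLogger(name).setLevel(level)
    # ...and the root logger, which newly created loggers defer to.
    logging.root.setLevel(level)


set_global_log_level("INFO")
module_logger.info("visible now that the effective level is INFO")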

@@ -745,7 +749,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
                f"{output_dir}/stacks-cuda-{name}.stacks", "self_cuda_time_total"
            )

-    if memory_snapshot:
+    if memory_snapshot and (all_rank_traces or rank == 0):
        torch.cuda.empty_cache()
        torch.cuda.memory._record_memory_history(
            max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
@@ -771,7 +775,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
    else:
        torch.cuda.synchronize(rank)

-    if memory_snapshot:
+    if memory_snapshot and (all_rank_traces or rank == 0):
        try:
            torch.cuda.memory._dump_snapshot(
                f"{output_dir}/memory-{name}-rank{rank}.pickle"
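The added and (all_rank_traces or rank == 0) clause makes memory snapshots follow the same per-rank policy as the profiler traces: by default only rank 0 records memory history and dumps a pickle, and all_rank_traces opts every rank in, so multi-GPU runs do not write one snapshot per rank unless asked to. A reduced sketch of the record/dump pair around a workload (assumes a CUDA device; the wrapper name and the constant's value are placeholders):

from typing import Callable

import torch

MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT = 100_000  # placeholder value


def run_with_memory_snapshot(
    fn: Callable[[], None],
    rank: int,
    name: str,
    output_dir: str,
    memory_snapshot: bool,
    all_rank_traces: bool,
) -> None:
    take_snapshot = memory_snapshot and (all_rank_traces or rank == 0)

    if take_snapshot:
        torch.cuda.empty_cache()
        torch.cuda.memory._record_memory_history(
            max_entries=MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
        )

    fn()  # the workload being benchmarked

    if take_snapshot:
        torch.cuda.memory._dump_snapshot(
            f"{output_dir}/memory-{name}-rank{rank}.pickle"
        )
        torch.cuda.memory._record_memory_history(enabled=None)  # stop recording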
@@ -857,6 +861,7 @@ class BenchFuncConfig:
    export_stacks: bool = False
    all_rank_traces: bool = False
    memory_snapshot: bool = False
+    loglevel: str = "WARNING"

    # pyre-ignore [2]
    def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
@@ -873,6 +878,10 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
            "memory_snapshot": self.memory_snapshot,
        } | kwargs_to_override

+    def set_log_level(self) -> None:
+        loglevel = logging._nameToLevel[self.loglevel.upper()]
+        logging.root.setLevel(loglevel)
+

def benchmark_func(
    name: str,
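BenchFuncConfig now carries the log level itself and can apply it with set_log_level(), so code that only holds the config object can configure logging without going through the CLI decorator. A hedged usage sketch, assuming the class accepts these fields as keyword arguments (the diff shows them as defaulted fields) and using a made-up benchmark name:

from torchrec.distributed.benchmark.base import BenchFuncConfig

# Defaults stay quiet (loglevel="WARNING"); opt into verbose logs for this run.
run_option = BenchFuncConfig(loglevel="INFO", memory_snapshot=True)
run_option.set_log_level()  # applies INFO to the root logger

# Assemble kwargs for benchmark_func; extra entries override the config's values.
func_kwargs = run_option.benchmark_func_kwargs(name="my_pipeline")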

torchrec/distributed/benchmark/benchmark_train_pipeline.py

Lines changed: 1 addition & 1 deletion
@@ -129,7 +129,7 @@ def runner(
        torch.cuda.is_available() and torch.cuda.device_count() >= world_size
    ), "CUDA not available or insufficient GPUs for the requested world_size"

-    torch.autograd.set_detect_anomaly(True)
+    run_option.set_log_level()
    with MultiProcessContext(
        rank=rank,
        world_size=world_size,
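runner executes inside each spawned benchmark process, so calling run_option.set_log_level() there applies the configured level in every worker, not just the parent; with a spawn-style start, logging configuration from the parent is not inherited. The always-on torch.autograd.set_detect_anomaly(True), a debugging aid with noticeable overhead, is dropped at the same time. A toy illustration of per-process log setup with torch.multiprocessing (not the torchrec runner):

import logging

import torch.multiprocessing as mp


def worker(rank: int, loglevel: str) -> None:
    # With the spawn start method each worker is a fresh interpreter,
    # so logging has to be configured here rather than in the parent.
    logging.basicConfig()
    logging.root.setLevel(logging._nameToLevel[loglevel.upper()])
    logging.getLogger(__name__).info("rank %d: logging configured", rank)


if __name__ == "__main__":
    mp.spawn(worker, args=("INFO",), nprocs=2, join=True)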
