@@ -425,17 +425,11 @@ def _load_config_file(
         if not config_path:
             return {}

-        try:
-            with open(config_path, "r") as f:
-                if is_json:
-                    return json.load(f) or {}
-                else:
-                    return yaml.safe_load(f) or {}
-        except Exception as e:
-            logger.error(
-                f"Failed to load config because {e}. Proceeding without it."
-            )
-            return {}
+        with open(config_path, "r") as f:
+            if is_json:
+                return json.load(f) or {}
+            else:
+                return yaml.safe_load(f) or {}

     @functools.wraps(func)
     def wrapper() -> Any:  # pyre-ignore [3]
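
Note on this hunk: dropping the try/except means a missing or malformed config file now raises instead of being logged and silently ignored. A minimal standalone sketch of the new behavior (the flat helper layout here is illustrative, not the module's actual nesting):

import json
from typing import Any, Dict, Optional

import yaml  # PyYAML, as in the real module


def load_config_file(config_path: Optional[str], is_json: bool = False) -> Dict[str, Any]:
    if not config_path:
        return {}
    # Errors now propagate to the caller instead of being swallowed,
    # so a bad config fails fast rather than proceeding with defaults.
    with open(config_path, "r") as f:
        if is_json:
            return json.load(f) or {}
        return yaml.safe_load(f) or {}
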
@@ -479,7 +473,12 @@ def wrapper() -> Any: # pyre-ignore [3]
         # Merge the two dictionaries, JSON overrides YAML
         merged_defaults = {**yaml_defaults, **json_defaults}

-        seen_args = set()  # track all --<name> we've added
+        # track all --<name> we've added
+        seen_args = {
+            "json_config",
+            "yaml_config",
+            "loglevel",
+        }

         for _name, param in sig.parameters.items():
             cls = param.annotation
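
Pre-seeding seen_args with the reserved flag names keeps a dataclass field that happens to be called json_config, yaml_config, or loglevel from being registered with argparse a second time. A toy sketch of that skip logic (the field names are hypothetical):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--json_config")
parser.add_argument("--yaml_config")
parser.add_argument("--loglevel", default="warning")

# track all --<name> we've added, pre-seeded with the reserved flags
seen_args = {"json_config", "yaml_config", "loglevel"}
for field_name in ("batch_size", "loglevel"):  # "loglevel" collides
    if field_name in seen_args:
        continue  # re-adding it would raise argparse.ArgumentError
    parser.add_argument(f"--{field_name}")
    seen_args.add(field_name)
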
@@ -548,7 +547,12 @@ def wrapper() -> Any: # pyre-ignore [3]
             logger.info(config_instance)

         loglevel = logging._nameToLevel[args.loglevel.upper()]
-        logger.setLevel(loglevel)
+        # Set loglevel for all existing loggers
+        for existing_logger_name in logging.root.manager.loggerDict:
+            existing_logger = logging.getLogger(existing_logger_name)
+            existing_logger.setLevel(loglevel)
+        # Also set the root logger
+        logging.root.setLevel(loglevel)

         return func(**kwargs)

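The loop walks logging.root.manager.loggerDict, which maps logger names to Logger or PlaceHolder entries; going through logging.getLogger(name) materializes real loggers before setting the level. A self-contained sketch:

import logging

logging.getLogger("module_a")
logging.getLogger("module_b.child")

loglevel = logging._nameToLevel["DEBUG"]
# Set loglevel for every logger created so far
for name in logging.root.manager.loggerDict:
    logging.getLogger(name).setLevel(loglevel)
# Also set the root logger
logging.root.setLevel(loglevel)

assert logging.getLogger("module_a").level == logging.DEBUG
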
@@ -745,7 +749,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
745749 f"{ output_dir } /stacks-cuda-{ name } .stacks" , "self_cuda_time_total"
746750 )
747751
748- if memory_snapshot :
752+ if memory_snapshot and ( all_rank_traces or rank == 0 ) :
749753 torch .cuda .empty_cache ()
750754 torch .cuda .memory ._record_memory_history (
751755 max_entries = MAX_NUM_OF_MEM_EVENTS_PER_SNAPSHOT
@@ -771,7 +775,7 @@ def _trace_handler(prof: torch.profiler.profile) -> None:
     else:
         torch.cuda.synchronize(rank)

-    if memory_snapshot:
+    if memory_snapshot and (all_rank_traces or rank == 0):
         try:
             torch.cuda.memory._dump_snapshot(
                 f"{output_dir}/memory-{name}-rank{rank}.pickle"
@@ -857,6 +861,7 @@ class BenchFuncConfig:
     export_stacks: bool = False
     all_rank_traces: bool = False
     memory_snapshot: bool = False
+    loglevel: str = "WARNING"

     # pyre-ignore [2]
     def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
@@ -873,6 +878,10 @@ def benchmark_func_kwargs(self, **kwargs_to_override) -> Dict[str, Any]:
873878 "memory_snapshot" : self .memory_snapshot ,
874879 } | kwargs_to_override
875880
881+ def set_log_level (self ) -> None :
882+ loglevel = logging ._nameToLevel [self .loglevel .upper ()]
883+ logging .root .setLevel (loglevel )
884+
876885
877886def benchmark_func (
878887 name : str ,
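
A hedged usage sketch for the new loglevel field and set_log_level method, assuming BenchFuncConfig's remaining fields all have defaults like the visible ones:

import logging

config = BenchFuncConfig(loglevel="info", memory_snapshot=True)
config.set_log_level()  # logging._nameToLevel["INFO"] applied to the root logger
assert logging.root.level == logging.INFO

kwargs = config.benchmark_func_kwargs(name="toy_bench")  # overrides merge via |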