diff --git a/py/torch_tensorrt/dynamo/utils.py b/py/torch_tensorrt/dynamo/utils.py
index ee11e597a1..a85494239e 100644
--- a/py/torch_tensorrt/dynamo/utils.py
+++ b/py/torch_tensorrt/dynamo/utils.py
@@ -11,6 +11,7 @@
 from torch._subclasses.fake_tensor import FakeTensor
 from torch_tensorrt._Device import Device
 from torch_tensorrt._enums import dtype
+from torch_tensorrt._features import ENABLED_FEATURES
 from torch_tensorrt._Input import Input
 from torch_tensorrt.dynamo import _defaults
 from torch_tensorrt.dynamo._defaults import default_device
@@ -204,10 +205,27 @@ def set_log_level(parent_logger: Any, level: Any) -> None:
     Sets the log level to the user provided level.
     This is used to set debug logging at a global level
     at entry points of tracing, dynamo and torch_compile compilation.
+    Also sets the log level for the C++ Torch-TensorRT logger if the runtime is available.
     """
     if parent_logger:
         parent_logger.setLevel(level)
 
+    if ENABLED_FEATURES.torch_tensorrt_runtime:
+        if level == logging.DEBUG:
+            log_level = trt.ILogger.Severity.VERBOSE
+        elif level == logging.INFO:
+            log_level = trt.ILogger.Severity.INFO
+        elif level == logging.WARNING:
+            log_level = trt.ILogger.Severity.WARNING
+        elif level == logging.ERROR:
+            log_level = trt.ILogger.Severity.ERROR
+        elif level == logging.CRITICAL:
+            log_level = trt.ILogger.Severity.INTERNAL_ERROR
+        else:
+            raise AssertionError(f"{level} is not a valid log level")
+
+        torch.ops.tensorrt.set_logging_level(int(log_level))
+
 
 def prepare_inputs(
     inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
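For reviewers, a minimal usage sketch of the new behavior. This is an illustration under stated assumptions, not part of the patch: the logger name `torch_tensorrt.dynamo` is chosen only for the example, and the C++ path fires only when the build enables `ENABLED_FEATURES.torch_tensorrt_runtime`.

```python
import logging

from torch_tensorrt.dynamo.utils import set_log_level

# Any Python logger can serve as the parent; the name here is purely
# illustrative and not something the patch requires.
parent_logger = logging.getLogger("torch_tensorrt.dynamo")

# Sets the Python logger to DEBUG. With the C++ runtime available, the patch
# additionally maps DEBUG -> trt.ILogger.Severity.VERBOSE and forwards it via
# torch.ops.tensorrt.set_logging_level; without the runtime, only the Python
# logger level changes.
set_log_level(parent_logger, logging.DEBUG)

# Only the five standard levels are mapped (DEBUG/INFO/WARNING/ERROR/CRITICAL).
# When the runtime feature is enabled, any other value (e.g. a custom numeric
# level such as 15) falls through to the AssertionError branch.
try:
    set_log_level(parent_logger, 15)
except AssertionError as exc:
    print(exc)  # "15 is not a valid log level" (raised only with the runtime)
```

One design note on the mapping: CRITICAL is translated to `Severity.INTERNAL_ERROR`, the most severe TensorRT level, so lowering the Python log level past ERROR also suppresses everything except internal-error messages from the engine.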