|
11 | 11 | from torch._subclasses.fake_tensor import FakeTensor
|
12 | 12 | from torch_tensorrt._Device import Device
|
13 | 13 | from torch_tensorrt._enums import dtype
|
| 14 | +from torch_tensorrt._features import ENABLED_FEATURES |
14 | 15 | from torch_tensorrt._Input import Input
|
15 | 16 | from torch_tensorrt.dynamo import _defaults
|
16 | 17 | from torch_tensorrt.dynamo._defaults import default_device
|
def set_log_level(parent_logger: Any, level: Any) -> None:
    """
    Sets the log level to the user provided level.
    This is used to set debug logging at a global level
    at entry points of tracing, dynamo and torch_compile compilation.
    And set log level for c++ torch trt logger if runtime is available.
    """
    if parent_logger:
        parent_logger.setLevel(level)

    if ENABLED_FEATURES.torch_tensorrt_runtime:
        # Map Python logging levels onto the corresponding TensorRT
        # ILogger severities (CRITICAL has no direct analogue, so it is
        # routed to INTERNAL_ERROR).
        severity_by_level = {
            logging.DEBUG: trt.ILogger.Severity.VERBOSE,
            logging.INFO: trt.ILogger.Severity.INFO,
            logging.WARNING: trt.ILogger.Severity.WARNING,
            logging.ERROR: trt.ILogger.Severity.ERROR,
            logging.CRITICAL: trt.ILogger.Severity.INTERNAL_ERROR,
        }
        trt_severity = severity_by_level.get(level)
        if trt_severity is None:
            # Unknown level: fail loudly rather than silently mis-logging.
            raise AssertionError(f"{level} is not valid log level")

        # Forward the translated severity to the C++ torch-trt runtime logger.
        torch.ops.tensorrt.set_logging_level(int(trt_severity))
211 | 229 |
|
212 | 230 | def prepare_inputs(
|
213 | 231 | inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
|
|
0 commit comments