Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust cpp torch trt logging level with compiler option #3181

Merged
merged 5 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions core/runtime/TRTEngine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,10 @@ TRTEngine::TRTEngine(
}
num_io = std::make_pair(inputs_size, outputs);
}

#ifndef NDEBUG
this->enable_profiling();
if (util::logging::get_logger().get_reportable_log_level() == util::logging::LogLevel::kDEBUG) {
keehyuna marked this conversation as resolved.
Show resolved Hide resolved
this->enable_profiling();
}
#endif
LOG_DEBUG(*this);
}
Expand Down
5 changes: 1 addition & 4 deletions core/runtime/TRTEngine.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,8 @@ struct TRTEngine : torch::CustomClassHolder {
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);

void set_profiling_paths();
#ifndef NDEBUG
bool profile_execution = true;
#else
bool profile_execution = false;
#endif

std::string device_profile_path;
std::string input_profile_path;
std::string output_profile_path;
Expand Down
2 changes: 1 addition & 1 deletion core/util/logging/TorchTRTLogger.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ namespace {

TorchTRTLogger& get_global_logger() {
#ifndef NDEBUG
static TorchTRTLogger global_logger("[Torch-TensorRT - Debug Build] - ", LogLevel::kDEBUG, true);
static TorchTRTLogger global_logger("[Torch-TensorRT - Debug Build] - ", LogLevel::kINFO, true);
keehyuna marked this conversation as resolved.
Show resolved Hide resolved
#else
static TorchTRTLogger global_logger("[Torch-TensorRT] - ", LogLevel::kWARNING, false);
#endif
Expand Down
16 changes: 16 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,10 +204,26 @@ def set_log_level(parent_logger: Any, level: Any) -> None:
Sets the log level to the user provided level.
This is used to set debug logging at a global level
at entry points of tracing, dynamo and torch_compile compilation.
It also set log level for c++ torch trt logger
"""
if parent_logger:
parent_logger.setLevel(level)

if level == logging.DEBUG:
log_level = trt.ILogger.Severity.VERBOSE
elif level == logging.INFO:
log_level = trt.ILogger.Severity.INFO
elif level == logging.WARNING:
log_level = trt.ILogger.Severity.WARNING
elif level == logging.ERROR:
log_level = trt.ILogger.Severity.ERROR
elif level == logging.CRITICAL:
log_level = trt.ILogger.Severity.INTERNAL_ERROR
else:
raise AssertionError(f"{level} is valid log level")

torch.ops.tensorrt.set_logging_level(int(log_level))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens in the case that it is a python only build? You might want to use enabled features to gate access to this code to only when the C++ runtime is available.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right, I see the problem. Updated the change.

AttributeError: '_OpNamespace' 'tensorrt' object has no attribute 'set_logging_level'



def prepare_inputs(
inputs: Input | torch.Tensor | Sequence[Any] | Dict[Any, Any],
Expand Down
Loading