Commit a90191d
chore: rebase with main
peri044 committed Sep 24, 2024
2 parents 0de0b16 + 43eb560
Showing 331 changed files with 15,767 additions and 2,459 deletions.
6 changes: 6 additions & 0 deletions .pre-commit-config.yaml
@@ -55,6 +55,12 @@ repos:
rev: v1.22.9
hooks:
- id: typos
- repo: https://github.com/astral-sh/uv-pre-commit
# uv version.
rev: 0.4.10
hooks:
# Update the uv lockfile
- id: uv-lock
- repo: local
hooks:
- id: dont-commit-upstream
23 changes: 19 additions & 4 deletions core/lowering/passes/unpack_scaled_dot_product_attention.cpp
@@ -12,12 +12,12 @@ namespace passes {
// https://pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph) {
std::string sdpa_pattern = R"IR(
- graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
-   %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale)
+ graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
+   %out: Tensor = aten::scaled_dot_product_attention(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa)
return (%out))IR";

std::string unpacked_sdpa_pattern = R"IR(
- graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
+ graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
%none : NoneType = prim::Constant()
%1 : int = prim::Constant[value=-1]()
%2 : int = prim::Constant[value=-2]()
@@ -33,7 +33,7 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
return(%out))IR";

std::string unpacked_sdpa_attn_biased_pattern = R"IR(
- graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale):
+ graph(%query, %key, %value, %attn_mask, %dropout_p, %is_causal, %scale, %enable_gqa):
%none : NoneType = prim::Constant()
%0 : int = prim::Constant[value=1]()
%1 : int = prim::Constant[value=-1]()
@@ -69,6 +69,16 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
if (attn_mask_node->kind() != at::prim::Constant || !attn_mask_node->mustBeNone()) {
return false;
}
auto enable_gqa_node = match.anchor->inputs().at(7)->node();
if (enable_gqa_node->kind() != at::prim::Constant) {
LOG_WARNING(
"Could not unpack scaled_dot_product_attention with non constant enable_gqa: " << *enable_gqa_node);
return false;
}
if (enable_gqa_node->i(at::attr::value) == 1) {
LOG_WARNING("Could not unpack scaled_dot_product_attention with enable_gqa = True: " << *enable_gqa_node);
return false;
}
return true;
});

@@ -83,6 +93,11 @@ void UnpackScaledDotProductAttention(std::shared_ptr<torch::jit::Graph>& graph)
// messages already written in first pass, do not write again
return false;
}
auto enable_gqa_node = match.anchor->inputs().at(7)->node();
if (enable_gqa_node->kind() != at::prim::Constant || enable_gqa_node->i(at::attr::value) == 1) {
// messages already written in first pass, do not write again
return false;
}
return true;
});
LOG_GRAPH("Post unpack scaled_dot_product_attention: " << *graph);
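
The guarded rewrite above relies on torch::jit::SubgraphRewriter's match filters: a callback inspects each candidate match and can veto it, which is how the pass now skips any scaled_dot_product_attention call whose %enable_gqa input is non-constant or True. A minimal standalone sketch of that mechanism (the pattern strings and function name are illustrative stand-ins, not the pass's real registration code):

#include <torch/csrc/jit/passes/subgraph_rewrite.h>

#include <string>
#include <unordered_map>

void rewrite_unless_gqa(
    std::shared_ptr<torch::jit::Graph>& graph,
    const std::string& pattern,
    const std::string& replacement) {
  torch::jit::SubgraphRewriter rewriter;
  rewriter.RegisterRewritePattern(pattern, replacement);
  rewriter.runOnGraph(
      graph,
      [](const torch::jit::Match& match,
         const std::unordered_map<std::string, torch::jit::Value*>&) {
        // Input 7 of the matched SDPA call is %enable_gqa; rewrite only when
        // it is the constant False, mirroring the checks in the pass above.
        auto gqa = match.anchor->inputs().at(7)->node();
        return gqa->kind() == at::prim::Constant && gqa->i(at::attr::value) == 0;
      });
}
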
18 changes: 15 additions & 3 deletions core/runtime/BUILD
@@ -21,6 +21,7 @@ cc_library(
name = "runtime",
srcs = [
"DeviceList.cpp",
"Platform.cpp",
"RTDevice.cpp",
"TRTEngine.cpp",
"TRTEngineProfiler.cpp",
@@ -29,6 +30,7 @@ cc_library(
"runtime.cpp",
],
hdrs = [
"Platform.h",
"RTDevice.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
@@ -41,16 +43,26 @@ cc_library(
"//core/plugins:torch_tensorrt_plugins",
"//core/util:prelude",
] + select({
":windows": ["@tensorrt_win//:nvinfer", "@libtorch_win//:libtorch"],
":use_pre_cxx11_abi": ["@tensorrt//:nvinfer", "@libtorch_pre_cxx11_abi//:libtorch"],
"//conditions:default": ["@tensorrt//:nvinfer", "@libtorch"],
":use_pre_cxx11_abi": [
"@libtorch_pre_cxx11_abi//:libtorch",
"@tensorrt//:nvinfer",
],
":windows": [
"@libtorch_win//:libtorch",
"@tensorrt_win//:nvinfer",
],
"//conditions:default": [
"@libtorch",
"@tensorrt//:nvinfer",
],
}),
alwayslink = True,
)

pkg_tar(
name = "include",
srcs = [
"Platform.h",
"RTDevice.h",
"TRTEngine.h",
"TRTEngineProfiler.h",
2 changes: 2 additions & 0 deletions core/runtime/CMakeLists.txt
@@ -9,13 +9,15 @@ set(CXX_SRCS
"${CMAKE_CURRENT_SOURCE_DIR}/execute_engine.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/register_jit_hooks.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.cpp"
"${CMAKE_CURRENT_SOURCE_DIR}/Platform.cpp"
)

set(HEADER_FILES
"${CMAKE_CURRENT_SOURCE_DIR}/RTDevice.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngine.h"
"${CMAKE_CURRENT_SOURCE_DIR}/TRTEngineProfiler.h"
"${CMAKE_CURRENT_SOURCE_DIR}/runtime.h"
"${CMAKE_CURRENT_SOURCE_DIR}/Platform.h"
)

target_sources(${lib_name}
101 changes: 101 additions & 0 deletions core/runtime/Platform.cpp
@@ -0,0 +1,101 @@
#include "core/runtime/Platform.h"
#include "core/runtime/runtime.h"
#include "core/util/prelude.h"

namespace torch_tensorrt {
namespace core {
namespace runtime {

namespace {
const std::unordered_map<std::string, Platform::PlatformEnum>& get_name_to_platform_map() {
static const std::unordered_map<std::string, Platform::PlatformEnum> name_to_platform_map = {
{"linux_aarch64", Platform::PlatformEnum::kLINUX_AARCH64},
{"linux_x86_64", Platform::PlatformEnum::kLINUX_X86_64},
{"windows_x86_64", Platform::PlatformEnum::kWIN_X86_64},
{"unknown", Platform::PlatformEnum::kUNKNOWN},
};
return name_to_platform_map;
}

const std::unordered_map<Platform::PlatformEnum, std::string>& _get_platform_name_map() {
static const std::unordered_map<Platform::PlatformEnum, std::string> platform_name_map = {
{Platform::PlatformEnum::kLINUX_AARCH64, "linux_aarch64"},
{Platform::PlatformEnum::kLINUX_X86_64, "linux_x86_64"},
{Platform::PlatformEnum::kWIN_X86_64, "windows_x86_64"},
{Platform::PlatformEnum::kUNKNOWN, "unknown"}};
return platform_name_map;
}
} // namespace

const std::unordered_map<Platform::PlatformEnum, std::string>& get_platform_name_map() {
return _get_platform_name_map();
}

Platform::Platform() : _platform{Platform::PlatformEnum::kUNKNOWN} {}

Platform::Platform(Platform::PlatformEnum val) : _platform{val} {}

Platform::Platform(const std::string& platform_str) {
auto name_map = get_name_to_platform_map();
auto it = name_map.find(platform_str);
if (it != name_map.end()) {
_platform = it->second;
} else {
LOG_WARNING("Unknown platform " << platform_str);
_platform = Platform::PlatformEnum::kUNKNOWN;
}
}

std::string Platform::serialize() const {
auto name_map = get_platform_name_map();
auto it = name_map.find(_platform);
if (it != name_map.end()) {
return it->second;
} else {
LOG_WARNING("Attempted to serialized unknown platform tag");
return std::string("unknown");
}
}

Platform& Platform::operator=(const Platform& other) {
_platform = other._platform;
return (*this);
}

bool operator==(const Platform& lhs, const Platform& rhs) {
return lhs._platform == rhs._platform;
}

std::ostream& operator<<(std::ostream& os, const Platform& platform) {
os << platform.serialize();
return os;
}

Platform get_current_platform() {
#if defined(__linux__) || defined(__gnu_linux__)
#if defined(__aarch64__)
return Platform(Platform::PlatformEnum::kLINUX_AARCH64);
#elif defined(__amd64__) || defined(__x86_64__)
return Platform(Platform::PlatformEnum::kLINUX_X86_64);
#else
return Platform(Platform::PlatformEnum::kUNKNOWN);
#endif
#elif defined(_WIN32) || defined(_WIN64)
#if defined(_M_AMD64) || defined(_M_X64)
return Platform(Platform::PlatformEnum::kWIN_X86_64);
#else
return Platform(Platform::PlatformEnum::kUNKNOWN);
#endif
#else
return Platform(Platform::PlatformEnum::kUNKNOWN);
#endif
}

bool is_supported_on_current_platform(Platform target) {
// Space for more complicated platform support calculations later
return target == get_current_platform();
}

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
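
To make the new Platform API concrete, a small usage sketch (assuming core/runtime/Platform.h is on the include path; the output comments follow the maps defined above):

#include "core/runtime/Platform.h"

#include <iostream>

int main() {
  using namespace torch_tensorrt::core::runtime;

  // Round-trip through a registered tag.
  Platform linux_x86("linux_x86_64");
  std::cout << linux_x86 << std::endl; // "linux_x86_64"

  // Unrecognized tags degrade to kUNKNOWN with a warning rather than throwing.
  Platform mystery("freebsd_riscv64");
  std::cout << mystery << std::endl; // "unknown"

  // The platform we are running on always passes the support check.
  std::cout << std::boolalpha
            << is_supported_on_current_platform(get_current_platform())
            << std::endl; // "true"
  return 0;
}
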
35 changes: 35 additions & 0 deletions core/runtime/Platform.h
@@ -0,0 +1,35 @@
#pragma once
#include <string>
#include <unordered_map>

namespace torch_tensorrt {
namespace core {
namespace runtime {

struct Platform {
typedef enum {
kLINUX_X86_64 = 0,
kLINUX_AARCH64,
kWIN_X86_64,
kUNKNOWN,
} PlatformEnum;

PlatformEnum _platform = Platform::kUNKNOWN;

Platform();
Platform(PlatformEnum val);
Platform(const std::string& platform_str);
std::string serialize() const;
Platform& operator=(const Platform& other);

friend std::ostream& operator<<(std::ostream& os, const Platform& platform);
friend bool operator==(const Platform& lhs, const Platform& rhs);
};

const std::unordered_map<Platform::PlatformEnum, std::string>& get_platform_name_map();
Platform get_current_platform();
bool is_supported_on_current_platform(Platform target);

} // namespace runtime
} // namespace core
} // namespace torch_tensorrt
17 changes: 15 additions & 2 deletions core/runtime/TRTEngine.cpp
@@ -34,6 +34,7 @@ TRTEngine::TRTEngine(
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names,
const Platform& target_platform,
bool hardware_compatible,
const std::string& serialized_metadata)
: TRTEngine(
@@ -42,6 +43,7 @@
cuda_device,
_in_binding_names,
_out_binding_names,
target_platform,
hardware_compatible,
serialized_metadata) {}

@@ -52,6 +54,7 @@ TRTEngine::TRTEngine(std::vector<std::string> serialized_info)
RTDevice(serialized_info[DEVICE_IDX]),
split(serialized_info[INPUT_BINDING_NAMES_IDX], BINDING_DELIM),
split(serialized_info[OUTPUT_BINDING_NAMES_IDX], BINDING_DELIM),
Platform(serialized_info[TARGET_PLATFORM_IDX]),
static_cast<bool>(std::stoi(serialized_info[HW_COMPATIBLE_IDX])),
serialized_info[SERIALIZED_METADATA_IDX]) {}

@@ -61,12 +64,22 @@ TRTEngine::TRTEngine(
const RTDevice& cuda_device,
const std::vector<std::string>& _in_binding_names,
const std::vector<std::string>& _out_binding_names,
const Platform& target_platform,
bool hardware_compatible,
const std::string& serialized_metadata) {
TORCHTRT_CHECK(
is_supported_on_current_platform(target_platform),
"This engine was not built to run on this platform (built for: " << target_platform << ", current platform: "
<< get_current_platform() << ")");
this->target_platform = target_platform;

this->cudagraph_mempool_id = at::cuda::graph_pool_handle();

this->hardware_compatible = hardware_compatible;
this->serialized_metadata = serialized_metadata;
auto most_compatible_device = get_most_compatible_device(cuda_device, RTDevice(), hardware_compatible);
TORCHTRT_CHECK(most_compatible_device, "No compatible device was found for instantiating TensorRT engine");

- this->serialized_metadata = serialized_metadata;
device_info = most_compatible_device.value();
multi_gpu_device_check();
set_rt_device(device_info);
@@ -196,7 +209,6 @@ TRTEngine::TRTEngine(
}

TRTEngine::~TRTEngine() {
- cudagraph.reset();
trt_engine_profiler.reset();
exec_ctx.reset();
cuda_engine.reset();
@@ -276,6 +288,7 @@ std::string TRTEngine::to_str() const {
ss << " ]" << std::endl;
ss << " Device: " << device_info << std::endl;
ss << " Hardware Compatibility: " << (hardware_compatible ? "Enabled" : "Disabled") << std::endl;
ss << " Target Platform: " << target_platform << std::endl;
// clang-format on
return ss.str();
}
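
The net effect of the TRTEngine changes: the target platform is now serialized alongside the engine, and deserialization refuses to load an engine built for a different platform. A standalone sketch of that gate (a plain exception stands in for TORCHTRT_CHECK; this is not the constructor's actual code path):

#include "core/runtime/Platform.h"

#include <sstream>
#include <stdexcept>
#include <string>

void check_target_platform(const std::string& serialized_platform_tag) {
  using namespace torch_tensorrt::core::runtime;
  // Rebuild the Platform from the tag stored with the engine, e.g. "linux_x86_64".
  Platform target(serialized_platform_tag);
  if (!is_supported_on_current_platform(target)) {
    std::ostringstream msg;
    msg << "This engine was not built to run on this platform (built for: "
        << target << ", current platform: " << get_current_platform() << ")";
    throw std::runtime_error(msg.str());
  }
}

// e.g. check_target_platform("windows_x86_64") throws on a Linux x86_64 host.
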
7 changes: 7 additions & 0 deletions core/runtime/TRTEngine.h
@@ -39,24 +39,30 @@ struct TRTEngine : torch::CustomClassHolder {
bool hardware_compatible = false; // Whether the engine was compiled in hardware compatible mode
std::string serialized_metadata; // This is a base64 encoded pkl object used to store metadata such as settings used
// in compilation
Platform target_platform;

~TRTEngine();
TRTEngine(
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names,
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
const std::string& serialized_metadata = "");

TRTEngine(std::vector<std::string> serialized_info);

TRTEngine(
const std::string& mod_name,
const std::string& serialized_engine,
const RTDevice& cuda_device,
const std::vector<std::string>& in_binding_names,
const std::vector<std::string>& out_binding_names,
const Platform& target_platform = get_current_platform(),
bool hardware_compatible = false,
const std::string& serialized_metadata = "");

TRTEngine& operator=(const TRTEngine& other);
std::string to_str() const;
static void verify_serialization_fmt(const std::vector<std::string>& serialized_info);
@@ -75,6 +81,7 @@ struct TRTEngine : torch::CustomClassHolder {
std::vector<at::Tensor> input_buffers = {};
std::vector<at::Tensor> output_buffers = {};
std::string shape_key;
at::cuda::MempoolId_t cudagraph_mempool_id;

// TODO: Implement a call method
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
2 changes: 1 addition & 1 deletion core/runtime/execute_engine.cpp
@@ -333,7 +333,7 @@ std::vector<at::Tensor> execute_engine(std::vector<at::Tensor> inputs, c10::intr
if (need_cudagraphs_record) {
// If cudagraphs needs to record a graph, capture the enqueueV3 call in a graph
c10::cuda::CUDAStream recording_stream = compiled_engine->engine_stream;
- compiled_engine->cudagraph.capture_begin();
+ compiled_engine->cudagraph.capture_begin(compiled_engine->cudagraph_mempool_id);
compiled_engine->exec_ctx->enqueueV3(recording_stream);
compiled_engine->cudagraph.capture_end();

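
The capture_begin change pins every re-recording of the CUDA graph to the engine's dedicated memory pool (cudagraph_mempool_id, created once via at::cuda::graph_pool_handle() in the TRTEngine constructor), so repeated captures reuse allocations instead of growing a fresh pool each time. A hedged sketch of that pattern in plain ATen (outside the Torch-TensorRT runtime):

#include <ATen/cuda/CUDAGraph.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAStream.h>

void record_into_shared_pool() {
  // Created once and reused across re-captures, like cudagraph_mempool_id.
  at::cuda::MempoolId_t pool = at::cuda::graph_pool_handle();
  at::cuda::CUDAGraph graph;

  // CUDA graph capture must happen on a non-default stream.
  c10::cuda::CUDAStream stream = c10::cuda::getStreamFromPool();
  c10::cuda::CUDAStreamGuard guard(stream);

  graph.capture_begin(pool); // record allocations into the shared pool
  // ... enqueue the work to be replayed (in the runtime: exec_ctx->enqueueV3) ...
  graph.capture_end();

  graph.replay(); // later launches replay the captured work
}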