Add support for the ONNX Runtime Eager Mode backend (pytorch#58248)
Summary:
This PR implements the necessary hooks, stubs, enums, etc. for complete ONNX Runtime (ORT) Eager Mode integration. The actual extension will live out of tree at https://github.com/pytorch/ort.

We have been [working on this at Microsoft](https://github.com/microsoft/onnxruntime-pytorch/tree/eager-ort/torch_onnxruntime) for the last few months, and are finally ready to contribute the PyTorch core changes upstream (nothing major or exciting, just the usual boilerplate for adding new backends).

The ORT backend will let us ferry [almost] all torch ops to granular ONNX kernels that ORT executes eagerly on any device it supports (so, from the PyTorch perspective, a single ORT backend suffices).
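
As a taste of what this enables, here is a minimal C++ sketch of eager execution on the ORT device, assuming the out-of-tree extension has been loaded and has registered kernels for these ops (it mirrors the in-tree extension_backend_test.cpp changes below):

#include <ATen/ATen.h>

void ort_eager_sketch() {
  // Tensors created on the ORT device carry DispatchKey::ORT, so every
  // op on them is routed to whatever kernels the extension registered.
  at::Tensor a = at::empty({5, 5}, at::kORT);
  at::Tensor b = at::empty({5, 5}, at::kORT);
  at::Tensor c = at::add(a, b); // eagerly executed by an ORT kernel
}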

Pull Request resolved: pytorch#58248

Reviewed By: astaff

Differential Revision: D30344992

Pulled By: albanD

fbshipit-source-id: 69082b32121246340d686e16653626114b7714b2
abock authored and facebook-github-bot committed Aug 20, 2021
1 parent b95ce15 commit c78ab28
Showing 38 changed files with 236 additions and 120 deletions.
8 changes: 8 additions & 0 deletions aten/src/ATen/Context.h
@@ -9,6 +9,7 @@
#include <ATen/core/LegacyTypeDispatch.h>
#include <ATen/detail/CUDAHooksInterface.h>
#include <ATen/detail/HIPHooksInterface.h>
#include <ATen/detail/ORTHooksInterface.h>
#include <c10/util/Exception.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/core/QEngine.h>
@@ -79,6 +80,9 @@ class TORCH_API Context {
static bool hasMLC() {
return c10::impl::hasDeviceGuardImpl(at::DeviceType::MLC);
}
static bool hasORT() {
return c10::impl::hasDeviceGuardImpl(at::DeviceType::ORT);
}
// defined in header so that getNonVariableType has ability to inline
// call_once check. getNonVariableType is called fairly frequently
THCState* lazyInitCUDA() {
@@ -292,6 +296,10 @@ static inline bool hasMLC() {
return globalContext().hasMLC();
}

static inline bool hasORT() {
return globalContext().hasORT();
}

// Despite its name, this function returns the number of *CUDA* GPUs.
static inline size_t getNumGPUs() {
// WARNING: DO NOT ADD LOGIC TO HANDLE OTHER DEVICE TYPES TO THIS
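hasORT() only reports true once a DeviceGuardImpl has been registered for the ORT device, which the out-of-tree extension does when it loads. A hedged sketch of guarding on it (the function name is hypothetical):

#include <ATen/ATen.h>
#include <ATen/Context.h>

at::Tensor make_on_available_device() {
  // Fall back to CPU when the torch_ort extension is not loaded.
  auto device = at::hasORT() ? at::Device(at::kORT) : at::Device(at::kCPU);
  return at::empty({2, 2}, device);
}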
4 changes: 4 additions & 0 deletions aten/src/ATen/Version.cpp
@@ -184,6 +184,10 @@ std::string show_config() {
ss << detail::getCUDAHooks().showConfig();
}

if (hasORT()) {
ss << detail::getORTHooks().showConfig();
}

ss << " - Build settings: ";
for (const auto& pair : caffe2::GetBuildOptions()) {
if (!pair.second.empty()) {
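With this hook, the ORT section appears in at::show_config() (surfaced in Python as torch.__config__.show()) whenever the hooks are registered; a small sketch:

#include <ATen/Version.h>
#include <iostream>

int main() {
  // If the ORT hooks are registered, their showConfig() output is
  // appended to the config dump; otherwise the ORT section is absent.
  std::cout << at::show_config() << std::endl;
}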
1 change: 1 addition & 0 deletions aten/src/ATen/core/aten_interned_strings.h
@@ -405,6 +405,7 @@ _(aten, is_complex) \
_(aten, is_contiguous) \
_(aten, is_cuda) \
_(aten, is_mlc) \
_(aten, is_ort) \
_(aten, is_distributed) \
_(aten, is_floating_point) \
_(aten, is_inference) \
4 changes: 2 additions & 2 deletions aten/src/ATen/core/op_registration/README.md
@@ -13,13 +13,13 @@ There’s four main use cases
* You’re writing a new operator that isn’t supposed to be part of the public PyTorch API.
* You’re writing a new operator but don’t want to change the core pytorch code base, say you’re developing a shared library with operators.
* You’re writing a C++ extension for PyTorch or you’re using inline c++ in your .py model files.
* You’re writing a backend library like XLA or MSNPU that adds new kernels to all operators defined in `native_functions.yaml`.
* You’re writing a backend library like XLA or ORT that adds new kernels to all operators defined in `native_functions.yaml`.

For these use cases, the custom operator API is the better solution.

### What is the price for using the custom operator API instead of `native_functions.yaml`?

If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/MSNPU example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.
If you’re just using the custom operator API to add new kernels for existing operators (e.g. the XLA/ORT example above), then you’re fine and don’t pay any price. If, however, you define a new operator purely using the custom op API, i.e. your operator never shows up in `native_functions.yaml`, then you need to be aware of a few caveats.

* It will not get a C++ API generated. There will not be `Tensor::your_op()` methods or `at::your_op()` functions to call your operator.
* The API for calling the operator from Python looks a little bit different. It needs to be called through `torch.ops.your_op()` instead of `torch._C`.
31 changes: 31 additions & 0 deletions aten/src/ATen/detail/ORTHooksInterface.cpp
@@ -0,0 +1,31 @@
#include <ATen/detail/ORTHooksInterface.h>

#include <c10/util/Exception.h>

#include <cstddef>
#include <memory>
#include <mutex>

namespace at {
namespace detail {

// See getCUDAHooks for some more commentary
const ORTHooksInterface& getORTHooks() {
static std::unique_ptr<ORTHooksInterface> ort_hooks;
static std::once_flag once;
std::call_once(once, [] {
ort_hooks = ORTHooksRegistry()->Create("ORTHooks", {});
if (!ort_hooks) {
ort_hooks =
// NOLINTNEXTLINE(modernize-make-unique)
std::unique_ptr<ORTHooksInterface>(new ORTHooksInterface());
}
});
return *ort_hooks;
}
} // namespace detail

// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
C10_DEFINE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs)

} // namespace at
36 changes: 36 additions & 0 deletions aten/src/ATen/detail/ORTHooksInterface.h
@@ -0,0 +1,36 @@
#pragma once

#include <c10/util/Exception.h>
#include <c10/util/Registry.h>

constexpr const char* ORT_HELP =
" You need to 'import torch_ort' to use the 'ort' device in PyTorch. "
"The 'torch_ort' module is provided by the ONNX Runtime itself "
"(https://onnxruntime.ai).";

// NB: Class must live in `at` due to limitations of Registry.h.
namespace at {

struct TORCH_API ORTHooksInterface {
// This should never actually be implemented, but it is used to
// squelch -Werror=non-virtual-dtor
virtual ~ORTHooksInterface() {}

virtual std::string showConfig() const {
TORCH_CHECK(false, "Cannot query detailed ORT version information.", ORT_HELP);
}
};

// NB: dummy argument to suppress "ISO C++11 requires at least one argument
// for the "..." in a variadic macro"
struct TORCH_API ORTHooksArgs {};

C10_DECLARE_REGISTRY(ORTHooksRegistry, ORTHooksInterface, ORTHooksArgs);
#define REGISTER_ORT_HOOKS(clsname) \
C10_REGISTER_CLASS(ORTHooksRegistry, clsname, clsname)

namespace detail {
TORCH_API const ORTHooksInterface& getORTHooks();
} // namespace detail

} // namespace at
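
Note that the registry is keyed by class name: getORTHooks() asks the registry for the literal name "ORTHooks", so the out-of-tree implementation must register a class with exactly that name. A minimal sketch of such an implementation (the version string is a placeholder):

#include <ATen/detail/ORTHooksInterface.h>

#include <string>

namespace at {

// Must be named ORTHooks: getORTHooks() calls
// ORTHooksRegistry()->Create("ORTHooks", {}), and REGISTER_ORT_HOOKS
// registers the class under its own name.
struct ORTHooks : public ORTHooksInterface {
  ORTHooks(ORTHooksArgs) {}
  std::string showConfig() const override {
    return "  - ONNX Runtime (placeholder version string)\n";
  }
};

REGISTER_ORT_HOOKS(ORTHooks);

} // namespace at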
6 changes: 6 additions & 0 deletions aten/src/ATen/templates/TensorBody.h
@@ -492,6 +492,12 @@ class TORCH_API Tensor {
return impl_->is_mlc();
}

/// Returns if a `Tensor` is ort tensor.
bool is_ort() const {
// NB: this is not a native function to avoid dispatching overhead.
return impl_->is_ort();
}

/// Returns if a `Tensor` is vulkan tensor.
bool is_vulkan() const {
// NB: this is not a native function to avoid dispatching overhead.
23 changes: 14 additions & 9 deletions aten/src/ATen/test/extension_backend_test.cpp
@@ -6,6 +6,11 @@

#include <torch/csrc/jit/runtime/operator.h>

// NB. These tests use the ORT dispatch key to test backend dispatching
// machinery, but these tests are not specific to ORT at all. The ORT
// backend is fully out-of-tree, so it's safe to use this key for
// in-tree tests.

using namespace at;

static int test_int;
@@ -17,16 +22,16 @@ Tensor empty_override(IntArrayRef size, c10::optional<ScalarType> dtype, c10::op
Storage(
Storage::use_byte_size_t(),
0,
at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)),
at::DataPtr(nullptr, Device(DeviceType::ORT, 1)),
nullptr,
false),
DispatchKey::MSNPU,
DispatchKey::ORT,
caffe2::TypeMeta::Make<float>());
return Tensor(std::move(tensor_impl));
}

Tensor add_override(const Tensor & a, const Tensor & b , const Scalar& c) {
auto out = empty({5, 5}, at::kMSNPU); // Don't return self as-is
auto out = empty({5, 5}, at::kORT); // Don't return self as-is
test_int = 2;
return out;
}
@@ -42,28 +47,28 @@ Tensor empty_strided_override(
return empty_override(size, dtype, layout, device, pin_memory, c10::nullopt);
}

TORCH_LIBRARY_IMPL(aten, MSNPU, m) {
TORCH_LIBRARY_IMPL(aten, ORT, m) {
m.impl("aten::empty.memory_format", empty_override);
m.impl("aten::empty_strided", empty_strided_override);
m.impl("aten::add.Tensor", add_override);
}

TEST(BackendExtensionTest, TestRegisterOp) {
Tensor a = empty({5, 5}, at::kMSNPU);
ASSERT_EQ(a.device().type(), at::kMSNPU);
Tensor a = empty({5, 5}, at::kORT);
ASSERT_EQ(a.device().type(), at::kORT);
ASSERT_EQ(a.device().index(), 1);
ASSERT_EQ(a.dtype(), caffe2::TypeMeta::Make<float>());
ASSERT_EQ(test_int, 1);

Tensor b = empty_like(a, at::kMSNPU);
ASSERT_EQ(b.device().type(), at::kMSNPU);
Tensor b = empty_like(a, at::kORT);
ASSERT_EQ(b.device().type(), at::kORT);
ASSERT_EQ(b.device().index(), 1);
ASSERT_EQ(b.dtype(), caffe2::TypeMeta::Make<float>());

add(a, b);
ASSERT_EQ(test_int, 2);

// Ensure that non-MSNPU operator still works
// Ensure that non-ORT operator still works
Tensor d = empty({5, 5}, at::kCPU);
ASSERT_EQ(d.device().type(), at::kCPU);
}
18 changes: 9 additions & 9 deletions c10/core/Backend.h
@@ -40,7 +40,7 @@ enum class Backend {
SparseHIP,
SparseVE,
SparseXPU,
MSNPU,
ORT,
XLA,
Vulkan,
Metal,
@@ -66,8 +66,8 @@ static inline Backend dispatchKeyToBackend(DispatchKey t) {
return Backend::VE;
} else if (t == DispatchKey::FPGA) {
return Backend::FPGA;
} else if (t == DispatchKey::MSNPU) {
return Backend::MSNPU;
} else if (t == DispatchKey::ORT) {
return Backend::ORT;
} else if (t == DispatchKey::XLA || t == DispatchKey::AutogradXLA) {
return Backend::XLA;
} else if (t == DispatchKey::Lazy || t == DispatchKey::AutogradLazy) {
@@ -123,8 +123,8 @@ static inline DispatchKey backendToDispatchKey(Backend b) {
return DispatchKey::VE;
case Backend::FPGA:
return DispatchKey::FPGA;
case Backend::MSNPU:
return DispatchKey::MSNPU;
case Backend::ORT:
return DispatchKey::ORT;
case Backend::XLA:
return DispatchKey::XLA;
case Backend::Lazy:
@@ -178,8 +178,8 @@ static inline DeviceType backendToDeviceType(Backend b) {
return DeviceType::VE;
case Backend::FPGA:
return DeviceType::FPGA;
case Backend::MSNPU:
return DeviceType::MSNPU;
case Backend::ORT:
return DeviceType::ORT;
case Backend::XLA:
return DeviceType::XLA;
case Backend::Lazy:
@@ -235,8 +235,8 @@ static inline const char* toString(Backend b) {
return "FPGA";
case Backend::XPU:
return "XPU";
case Backend::MSNPU:
return "MSNPU";
case Backend::ORT:
return "ORT";
case Backend::XLA:
return "XLA";
case Backend::Lazy:
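The switch statements above keep the ORT identifiers consistent across the enum layers; a hedged sketch of the round-trip they guarantee (the function name is hypothetical):

#include <c10/core/Backend.h>
#include <c10/util/Exception.h>

#include <cstring>

void check_ort_round_trip() {
  c10::Backend b = c10::dispatchKeyToBackend(c10::DispatchKey::ORT);
  // Backend::ORT maps back to the matching dispatch key, device type,
  // and printable name.
  TORCH_INTERNAL_ASSERT(c10::backendToDispatchKey(b) == c10::DispatchKey::ORT);
  TORCH_INTERNAL_ASSERT(c10::backendToDeviceType(b) == c10::DeviceType::ORT);
  TORCH_INTERNAL_ASSERT(std::strcmp(c10::toString(b), "ORT") == 0);
}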
4 changes: 2 additions & 2 deletions c10/core/Device.cpp
@@ -28,7 +28,7 @@ DeviceType parse_type(const std::string& device_string) {
{"hip", DeviceType::HIP},
{"ve", DeviceType::VE},
{"fpga", DeviceType::FPGA},
{"msnpu", DeviceType::MSNPU},
{"ort", DeviceType::ORT},
{"xla", DeviceType::XLA},
{"lazy", DeviceType::Lazy},
{"vulkan", DeviceType::Vulkan},
@@ -47,7 +47,7 @@
}
TORCH_CHECK(
false,
"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, msnpu, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
"Expected one of cpu, cuda, xpu, mkldnn, opengl, opencl, ideep, hip, ve, ort, mlc, xla, lazy, vulkan, meta, hpu device type at start of device string: ",
device_string);
}
enum DeviceStringParsingState { START, INDEX_START, INDEX_REST, ERROR };
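Once "ort" is in the parse table, device strings work end to end; a small sketch:

#include <c10/core/Device.h>
#include <c10/util/Exception.h>

void parse_ort_device_string() {
  // parse_type() now recognizes "ort", so the Device(string) constructor
  // accepts ort device strings with or without an index.
  c10::Device d("ort:1");
  TORCH_CHECK(d.type() == c10::DeviceType::ORT);
  TORCH_CHECK(d.index() == 1);
}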
6 changes: 3 additions & 3 deletions c10/core/DeviceType.cpp
@@ -25,8 +25,8 @@ std::string DeviceTypeName(DeviceType d, bool lower_case) {
return lower_case ? "ve" : "VE";
case DeviceType::FPGA:
return lower_case ? "fpga" : "FPGA";
case DeviceType::MSNPU:
return lower_case ? "msnpu" : "MSNPU";
case DeviceType::ORT:
return lower_case ? "ort" : "ORT";
case DeviceType::XLA:
return lower_case ? "xla" : "XLA";
case DeviceType::Lazy:
@@ -75,7 +75,7 @@ bool isValidDeviceType(DeviceType d) {
case DeviceType::HIP:
case DeviceType::VE:
case DeviceType::FPGA:
case DeviceType::MSNPU:
case DeviceType::ORT:
case DeviceType::XLA:
case DeviceType::Lazy:
case DeviceType::MLC:
4 changes: 2 additions & 2 deletions c10/core/DeviceType.h
@@ -21,7 +21,7 @@ enum class DeviceType : int8_t {
IDEEP = 5, // IDEEP.
HIP = 6, // AMD HIP
FPGA = 7, // FPGA
MSNPU = 8, // MSNPU
ORT = 8, // ONNX Runtime / Microsoft
XLA = 9, // XLA / TPU
Vulkan = 10, // Vulkan
Metal = 11, // Metal
@@ -42,7 +42,7 @@ constexpr DeviceType kCPU = DeviceType::CPU;
constexpr DeviceType kCUDA = DeviceType::CUDA;
constexpr DeviceType kHIP = DeviceType::HIP;
constexpr DeviceType kFPGA = DeviceType::FPGA;
constexpr DeviceType kMSNPU = DeviceType::MSNPU;
constexpr DeviceType kORT = DeviceType::ORT;
constexpr DeviceType kXLA = DeviceType::XLA;
constexpr DeviceType kMLC = DeviceType::MLC;
constexpr DeviceType kMeta = DeviceType::Meta;
4 changes: 2 additions & 2 deletions c10/core/DispatchKey.cpp
@@ -19,8 +19,8 @@ const char* toString(DispatchKey t) {
return "FPGA";
case DispatchKey::XPU:
return "XPU";
case DispatchKey::MSNPU:
return "MSNPU";
case DispatchKey::ORT:
return "ORT";
case DispatchKey::XLA:
return "XLA";
case DispatchKey::Lazy:
13 changes: 10 additions & 3 deletions c10/core/DispatchKey.h
@@ -59,8 +59,15 @@ enum class DispatchKey : uint8_t {
// CUDA]
FPGA, // Xilinx support lives out of tree at
// https://gitlab.com/pytorch-complex/vitis_kernels
MSNPU, // unused externally, but tested at
// test/cpp_extensions/msnpu_extension.cpp

// ONNX Runtime, lives out of tree at https://github.com/pytorch/ort and
// https://github.com/microsoft/onnxruntime, and is also used to test general
// backend/extension machinery in the core. cf:
// - test/cpp_extensions/ort_extension.cpp
// - test/test_torch.py
// - aten/src/ATen/test/extension_backend_test.cpp
ORT,

XLA, // lives out of tree at https://github.com/pytorch/xla
MLC, // lives out of tree at https://github.com/pytorch/MLCompute
Vulkan,
@@ -114,7 +121,7 @@

// Here are reserved backends for user-defined backends, see Note [Private use
// DispatchKey]
// To see some example about how to use this, check out MSNPU
// To see some example about how to use this, check out ORT
PrivateUse1,
PrivateUse2,
PrivateUse3,
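For the reserved PrivateUse* keys, the registration pattern is the same one the ORT test above uses; a hypothetical sketch against PrivateUse1 (the kernel is a stub, not a real backend):

#include <torch/library.h>

at::Tensor private_add(const at::Tensor& a, const at::Tensor& b,
                       const at::Scalar& alpha) {
  // Stub: a real backend would execute the op on its own device.
  TORCH_CHECK(false, "PrivateUse1 add is a stub in this sketch");
  return a; // unreachable
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("aten::add.Tensor", private_add);
}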
1 change: 1 addition & 0 deletions c10/core/DispatchKeySet.cpp
@@ -19,6 +19,7 @@ constexpr DispatchKeySet backend_dispatch_keyset = autogradother_backends |
DispatchKey::PrivateUse3,
DispatchKey::MLC,
DispatchKey::HPU,
DispatchKey::ORT,
DispatchKey::Meta,
});

2 changes: 1 addition & 1 deletion c10/core/DispatchKeySet.h
@@ -248,7 +248,7 @@ constexpr DispatchKeySet autogradother_backends = DispatchKeySet(
{DispatchKey::HIP,
DispatchKey::VE,
DispatchKey::FPGA,
DispatchKey::MSNPU,
DispatchKey::ORT,
DispatchKey::Vulkan,
DispatchKey::Metal,
DispatchKey::QuantizedCPU,
4 changes: 4 additions & 0 deletions c10/core/TensorImpl.h
@@ -873,6 +873,10 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target {
return key_set_.has(DispatchKey::MLC);
}

bool is_ort() const {
return key_set_.has(DispatchKey::ORT);
}

// TODO: remove this once we don't automatically enabled Autograd dispatch
// keys
// in TensorImpl constructor.