85 changes: 55 additions & 30 deletions torch_tvm/compiler.cpp
@@ -10,24 +10,36 @@ using namespace torch::jit;

tvm::relay::Var TVMCompiler::convertToRelay(Value* val, TVMContext ctx) {
auto optional_ivalue = toIValue(val);
if (optional_ivalue.has_value()) {

tvm::Array<HalideIR::Expr> sizes;
if (auto ptt = val->type()->cast<ProfiledTensorType>())
{

auto csizes = ptt->sizes().concrete_sizes();
TORCH_INTERNAL_ASSERT(csizes.has_value());
Contributor: should we handle the case when this isn't true?

Contributor (Author): I think we just shouldn't create a TVMCompGroup in this case. Even though we can compile at runtime, if we get a workload where shapes change all the time it will be pretty wasteful. I'm looking into changing the Fuser to not fuse things that don't have ProfiledTensorTypes.

Contributor: If it's alternating between 20 batch-size shapes but run 1M times, it's probably still worth compiling each shape.

Contributor (Author): Hopefully this will be handled via bailouts. We specialize for shape_set1; then if we see another frequent set, we specialize for that one as well, and so on.

for (const auto& size : *csizes)
{
sizes.push_back(HalideIR::Expr(static_cast<int32_t>(size)));
}
} else if (optional_ivalue.has_value()) {
Contributor: I'm curious in which case we will have a value whose type is not a ProfiledTensorType, since we have all switched to the profiled graph executor?

// TODO: inferTypeFrom should eventually create ProfiledTensorTypes
val->inferTypeFrom(optional_ivalue.value().toTensor());
}
if (val->isCompleteTensor()) {
auto pt_t = val->type()->cast<CompleteTensorType>();
tvm::Array<HalideIR::Expr> sizes;
auto pt_t = val->type()->expect<CompleteTensorType>();
for (const auto& size : pt_t->sizes()) {
sizes.push_back(HalideIR::Expr(static_cast<int32_t>(size)));
}
// TODO: support non-float tensors
auto t = tvm::relay::TensorTypeNode::make(sizes, ::tvm::Float(32));
auto v = tvm::relay::VarNode::make(
val->debugName() +
std::to_string(reinterpret_cast<std::uintptr_t>(val)),
t);
return v;
}
AT_ASSERT(0);
else {
TORCH_INTERNAL_ASSERT(0);
}

// TODO: support non-float tensors
auto t = tvm::relay::TensorTypeNode::make(sizes, ::tvm::Float(32));
auto v = tvm::relay::VarNode::make(
val->debugName() +
std::to_string(reinterpret_cast<std::uintptr_t>(val)),
t);
return v;
}
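
Following up on the thread above about unprofiled inputs: one possible alternative to the hard TORCH_INTERNAL_ASSERT when a ProfiledTensorType has no concrete sizes would be to raise a catchable error, so that the existing try/catch in TVMCompiler::run() treats it as a failed conversion and falls back to the JIT interpreter. This is a hypothetical sketch, not part of this PR; the sizesOrThrow name and the error message are assumptions, and it requires #include <stdexcept>.

// Hypothetical helper: extract concrete sizes or throw a catchable error so
// TVMCompiler::run() falls back to the JIT interpreter instead of asserting.
static tvm::Array<HalideIR::Expr> sizesOrThrow(
    const std::shared_ptr<ProfiledTensorType>& ptt) {
  auto csizes = ptt->sizes().concrete_sizes();
  if (!csizes.has_value()) {
    throw std::runtime_error(
        "torch_tvm: input shapes were not profiled; cannot lower to Relay");
  }
  tvm::Array<HalideIR::Expr> sizes;
  for (const auto& size : *csizes) {
    sizes.push_back(HalideIR::Expr(static_cast<int32_t>(size)));
  }
  return sizes;
}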

tvm::relay::Expr TVMCompiler::convertToRelay(
@@ -95,9 +107,8 @@ tvm::relay::Function TVMCompiler::convertToRelay(
std::vector<Value*>* input_values) {
std::unordered_map<Value*, tvm::relay::Expr> value_map;
tvm::Array<tvm::relay::Var> input_vars;

for (const auto& input : subgraph->inputs()) {
AT_ASSERT(input->isCompleteTensor());
TORCH_INTERNAL_ASSERT(input->type()->cast<ProfiledTensorType>());
auto v = convertToRelay(input, ctx);
input_vars.push_back(v);
if (input_values) {
@@ -222,29 +233,40 @@ void TVMCompiler::run(Stack& stack) {
value_to_ivalue[value_input] = inputs[i];
}

CompleteArgumentSpec spec{false, ArrayRef<IValue>(inputs)};

if (cache_.find(spec) == cache_.end()) {
if (!cache_ || (cache_ && (*cache_).invalid)) {
Contributor: if cache_ is an optional, can we get rid of the invalid attribute somehow? It's pretty confusing.

for (auto& kv : value_to_ivalue) {
kv.first->inferTypeFrom(kv.second.toTensor());
// TODO: convince Fuser to NOT create TVMCompilationGroups
// if ANY of subgraph inputs weren't profiled
TORCH_INTERNAL_ASSERT(kv.first->type()->cast<ProfiledTensorType>());
Contributor: I don't think we need this, as it is done on line 111.

}
// Bail-out mechanism: try to convert to Relay; if the graph fails to convert
// for any reason (e.g. an op difference), then depending on the user preference
// either throw or fall back to the JIT interpreter for execution
cache_ = TVMObject {};
Contributor: At this point, if we run a graph twice with different sizes, will we still get TVM-compiled code each time? Is the logic for that moved up into the profiled executor?

Contributor (Author): Yup, we will bail out, profile again, and generate a graph with TVMCompGroups for the different shapes.

tvm::relay::Function tvm_func;
try {
tvm_func = convertToRelay(subgraph_, ctx_, &cache_[spec].input_values);
tvm_func = convertToRelay(subgraph_, ctx_, &(*cache_).input_values);
// we compiled the subgraph successfully
(*cache_).invalid = false;
} catch (const std::exception& e) {
(*cache_).invalid = true;
Contributor: When we fall back to the JIT, it means some operators were not converted successfully due to operator semantic mismatches or other behaviors. This invalid flag will let the compiler re-run the conversion every time for the same inputs, and it will always fail, so I am not sure that flag is necessary.

Contributor (Author): Exactly! We don't want to re-run if we already know it's going to fail. Re-running compilations might be pretty expensive.

if (strict_) {
AT_ERROR(
"Pytorch TVM: fail to convert to relay, exception: ", e.what());
}

LOG(WARNING)
<< "Pytorch TVM: fail to convert to relay, falling back to JIT for execution, exception: "
<< "Pytorch TVM: fail to convert to relay, exception: "
<< e.what() << "\n";
}

if ((*cache_).invalid)
Contributor: Can this block be merged back into the block on line 260?

{
LOG(WARNING) << "Falling back to JIT";
InterpreterState(Code(subgraph_)).run(stack);
return;
}

auto build_f = build_mod_.GetFunction("build", false);
auto json_f = build_mod_.GetFunction("get_graph_json", false);
auto mod_f = build_mod_.GetFunction("get_module", false);
@@ -255,38 +277,41 @@ void TVMCompiler::run(Stack& stack) {
tvm::runtime::Module mod = mod_f();
auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
AT_ASSERT(pfr);

tvm::runtime::Module run_mod =
(*pfr)(json, mod, (int)ctx_.device_type, (int)ctx_.device_id);
cache_[spec].set_input = run_mod.GetFunction("set_input_zero_copy", false);
cache_[spec].kernel = run_mod.GetFunction("run", false);
cache_[spec].get_output = run_mod.GetFunction("get_output", false);
(*cache_).set_input = run_mod.GetFunction("set_input_zero_copy", false);
Contributor: ->?

Contributor (Author): arrgh, thanks! I forgot to switch it.

(*cache_).kernel = run_mod.GetFunction("run", false);
(*cache_).get_output = run_mod.GetFunction("get_output", false);
auto get_num_outputs = run_mod.GetFunction("get_num_outputs", false);
int n = get_num_outputs();
AT_CHECK(
subgraph_->outputs().size() == n,
"Compiled subgraph with mismatching num outputs");

}

for (auto i = 0; i < cache_[spec].input_values.size(); ++i) {
auto* value = cache_[spec].input_values[i];
// setting arguments
for (auto i = 0; i < (*cache_).input_values.size(); ++i) {
auto* value = (*cache_).input_values[i];
if (!value_to_ivalue.count(value)) {
auto optional_ivalue = toIValue(value);
AT_ASSERT(optional_ivalue.has_value());
value_to_ivalue[value] = optional_ivalue.value();
}
auto ivalue = value_to_ivalue.at(cache_[spec].input_values[i]);
auto ivalue = value_to_ivalue.at((*cache_).input_values[i]);
auto tensor = ivalue.toTensor().to(at::kFloat);
auto dl_tensor = at::toDLPack(tensor);
cache_[spec].set_input(i, tvm::runtime::NDArray::FromDLPack(dl_tensor));
(*cache_).set_input(i, tvm::runtime::NDArray::FromDLPack(dl_tensor));
}

cache_[spec].kernel();
(*cache_).kernel();

// clean the stack and add outputs to the stack
drop(stack, num_inputs);
int i = 0;
for (const auto& output : subgraph_->outputs()) {
tvm::runtime::NDArray ret_val = cache_[spec].get_output(i);
tvm::runtime::NDArray ret_val = (*cache_).get_output(i);
auto dl_tensor = ret_val.ToDLPack();
auto tensor = at::fromDLPack(dl_tensor);
auto var = torch::autograd::make_variable(tensor);
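
Since the new caching logic in run() is spread across several hunks and review threads above, here is a condensed control-flow sketch of what the function does after this change. It is a reading aid, not the literal implementation; building the Relay module and the input/output plumbing are elided.

// Condensed sketch of TVMCompiler::run() after this PR (simplified).
void TVMCompiler::run(Stack& stack) {
  if (!cache_ || (*cache_).invalid) {   // no compiled object yet, or the last attempt failed
    cache_ = TVMObject{};               // TVMObject::invalid defaults to true
    try {
      auto tvm_func =
          convertToRelay(subgraph_, ctx_, &(*cache_).input_values);
      (*cache_).invalid = false;        // conversion succeeded
      // ... build the Relay function, stash set_input / kernel / get_output ...
    } catch (const std::exception& e) {
      if (strict_) {
        AT_ERROR("Pytorch TVM: fail to convert to relay, exception: ", e.what());
      }
    }
  }
  if ((*cache_).invalid) {              // conversion failed: run the subgraph in the JIT interpreter
    InterpreterState(Code(subgraph_)).run(stack);
    return;
  }
  // ... set inputs via (*cache_).set_input, run (*cache_).kernel(),
  // drop the inputs from the stack and push the outputs ...
}
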
3 changes: 2 additions & 1 deletion torch_tvm/compiler.h
@@ -15,6 +15,7 @@ struct TVMObject {
tvm::PackedFunc get_output;
// Map input indices to values in the subgraph
std::vector<torch::jit::Value*> input_values;
bool invalid = true;
};

struct TVMCompiler {
@@ -29,7 +30,7 @@ struct TVMCompiler {

private:
std::shared_ptr<torch::jit::Graph> subgraph_;
std::unordered_map<torch::jit::CompleteArgumentSpec, TVMObject> cache_;
c10::optional<TVMObject> cache_;
TVMContext ctx_;
int opt_level_;
bool strict_;
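
One way the review question above ("if cache_ is an optional, can we get rid of the invalid attribute somehow?") could be addressed is to name the compilation states explicitly instead of piggybacking a boolean on TVMObject. A hypothetical sketch, not something proposed in this PR; the CompileCache name and its layout are assumptions.

// Hypothetical alternative to `c10::optional<TVMObject> cache_` plus
// `bool TVMObject::invalid`: make the three states (never tried, tried and
// failed, compiled) explicit.
struct CompileCache {
  enum class State { kEmpty, kFailed, kCompiled };
  State state = State::kEmpty;
  c10::optional<TVMObject> obj;  // engaged only when state == kCompiled
};
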
2 changes: 2 additions & 0 deletions torch_tvm/register.cpp
@@ -5,6 +5,7 @@
#include <torch/csrc/jit/pass_manager.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/jit_log.h>

#include "compiler.h"
#include "operators.h"
Expand Down Expand Up @@ -35,6 +36,7 @@ PYBIND11_MODULE(_torch_tvm, m) {
RegisterOperators op({Operator(
tvm_sym,
[](const Node* node) {
GRAPH_DUMP("A graph passed to TVMCompiler\n", node->g(attr::Subgraph));
Contributor: can we move this to a different diff?

Contributor (Author): sure

auto cc = std::make_shared<TVMCompiler>(
node, opt_level, strict, device_type, device, host);
return [cc](Stack& stack) {