[WIP] Switch to using ProfiledTensorType #68
base: master
@@ -10,24 +10,36 @@ using namespace torch::jit;

```cpp
tvm::relay::Var TVMCompiler::convertToRelay(Value* val, TVMContext ctx) {
auto optional_ivalue = toIValue(val);
if (optional_ivalue.has_value()) {
tvm::Array<HalideIR::Expr> sizes;
if (auto ptt = val->type()->cast<ProfiledTensorType>())
{
auto csizes = ptt->sizes().concrete_sizes();
TORCH_INTERNAL_ASSERT(csizes.has_value());
for (const auto& size : *csizes)
{
sizes.push_back(HalideIR::Expr(static_cast<int32_t>(size)));
}
} else if (optional_ivalue.has_value()) {
```
> **Contributor:** curious in which case we will be having a value type that is not a …
```cpp
// TODO: inferTypeFrom should eventually create ProfiledTensorTypes
val->inferTypeFrom(optional_ivalue.value().toTensor());
}
if (val->isCompleteTensor()) {
auto pt_t = val->type()->cast<CompleteTensorType>();
tvm::Array<HalideIR::Expr> sizes;
auto pt_t = val->type()->expect<CompleteTensorType>();
for (const auto& size : pt_t->sizes()) {
sizes.push_back(HalideIR::Expr(static_cast<int32_t>(size)));
}
// TODO: support non-float tensors
auto t = tvm::relay::TensorTypeNode::make(sizes, ::tvm::Float(32));
auto v = tvm::relay::VarNode::make(
val->debugName() +
std::to_string(reinterpret_cast<std::uintptr_t>(val)),
t);
return v;
}
AT_ASSERT(0);
else {
TORCH_INTERNAL_ASSERT(0);
}

// TODO: support non-float tensors
auto t = tvm::relay::TensorTypeNode::make(sizes, ::tvm::Float(32));
auto v = tvm::relay::VarNode::make(
val->debugName() +
std::to_string(reinterpret_cast<std::uintptr_t>(val)),
t);
return v;
}

tvm::relay::Expr TVMCompiler::convertToRelay(
```
@@ -95,9 +107,8 @@ tvm::relay::Function TVMCompiler::convertToRelay(

```cpp
std::vector<Value*>* input_values) {
std::unordered_map<Value*, tvm::relay::Expr> value_map;
tvm::Array<tvm::relay::Var> input_vars;

for (const auto& input : subgraph->inputs()) {
AT_ASSERT(input->isCompleteTensor());
TORCH_INTERNAL_ASSERT(input->type()->cast<ProfiledTensorType>());
auto v = convertToRelay(input, ctx);
input_vars.push_back(v);
if (input_values) {
```
@@ -222,29 +233,40 @@ void TVMCompiler::run(Stack& stack) {

```cpp
value_to_ivalue[value_input] = inputs[i];
}

CompleteArgumentSpec spec{false, ArrayRef<IValue>(inputs)};

if (cache_.find(spec) == cache_.end()) {
if (!cache_ || (cache_ && (*cache_).invalid)) {
```
> **Contributor:** if …
```cpp
for (auto& kv : value_to_ivalue) {
kv.first->inferTypeFrom(kv.second.toTensor());
// TODO: convince Fuser to NOT create TVMCompilationGroups
// if ANY of subgraph inputs weren't profiled
TORCH_INTERNAL_ASSERT(kv.first->type()->cast<ProfiledTensorType>());
```
> **Contributor:** I don't think we need this, as it is done on line 111.
```cpp
}
// bail out mechanism: try to convert to Relay, if it fails to convert the
// graph by any reason(i.e. op difference), depend on the user preference,
// either throw or fall back to the JIT interpreter for execution
cache_ = TVMObject {};
```
> **Contributor:** at this point, if we run a graph twice with different sizes, will we still get TVM-compiled code each time?
>
> **Author:** yup, we will bail out, profile again, and generate a graph with TVMCompGroups for the different shapes.
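The bail-out flow described in this exchange can be illustrated with a small standalone sketch. Everything in it (`ShapeSet`, `CompiledKernel`, `Runner`) is a hypothetical stand-in for illustration, not the PR's actual classes; it only shows the control flow of reusing a kernel when the observed shapes match a prior specialization, and otherwise bailing out, re-profiling, and adding one more specialization.

```cpp
// Hypothetical sketch of the bail-out flow discussed above; ShapeSet,
// CompiledKernel, and Runner are illustrative stand-ins, not the PR's API.
#include <cstdint>
#include <cstdio>
#include <vector>

using ShapeSet = std::vector<std::vector<std::int64_t>>;  // one concrete shape per input

struct CompiledKernel {
  ShapeSet specialized_for;  // the shapes this kernel was compiled against
  void run() const { std::puts("running the TVM-compiled kernel"); }
};

struct Runner {
  std::vector<CompiledKernel> kernels;  // one entry per observed shape set

  void run(const ShapeSet& actual) {
    for (const auto& k : kernels) {
      if (k.specialized_for == actual) {  // shapes match an earlier specialization
        k.run();
        return;
      }
    }
    // Bail out: run this iteration in the interpreter, let the profiler record
    // the new shapes, then compile one more specialization for them.
    std::puts("bailing out to the interpreter and re-profiling");
    kernels.push_back(CompiledKernel{actual});
  }
};

int main() {
  Runner r;
  ShapeSet a = {{1, 3, 224, 224}};
  ShapeSet b = {{8, 3, 224, 224}};
  r.run(a);  // first run: bail out, then specialize for `a`
  r.run(a);  // second run: hits the kernel compiled for `a`
  r.run(b);  // new shapes: bail out again, specialize for `b`
  r.run(b);  // hits the kernel compiled for `b`
}
```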
```cpp
tvm::relay::Function tvm_func;
try {
tvm_func = convertToRelay(subgraph_, ctx_, &cache_[spec].input_values);
tvm_func = convertToRelay(subgraph_, ctx_, &(*cache_).input_values);
// we compiled the subgraph successfully
(*cache_).invalid = false;
} catch (const std::exception& e) {
(*cache_).invalid = true;
```
> **Contributor:** When we're falling back to JIT, it means some operators were not converted successfully due to operator semantic mismatch and other behaviors; this invalid flag will let the compiler re-run the conversion every time for the same inputs and it will always fail, so I am not sure if that flag would be necessary.
>
> **Author:** exactly! we don't want to re-run if we already know it's going to fail! Re-running compilations might be pretty expensive.
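A minimal standalone sketch of the pattern being discussed, remembering a failed conversion so it is not retried on every call. `TVMObject` here mirrors the name in the diff, but `convertToRelayOrThrow` and `runWithInterpreter` are stand-ins for illustration, not the code in this PR:

```cpp
// Sketch only: cache the failure once, then take the cheap fallback path on later calls.
#include <iostream>
#include <optional>
#include <stdexcept>

struct TVMObject {
  bool invalid = false;  // set to true once conversion to Relay has failed
};

std::optional<TVMObject> cache_;  // empty until the first run

void convertToRelayOrThrow() { throw std::runtime_error("unsupported op"); }
void runWithInterpreter() { std::cout << "falling back to JIT\n"; }

void run() {
  if (!cache_) {  // only ever attempt the conversion once
    cache_ = TVMObject{};
    try {
      convertToRelayOrThrow();
      cache_->invalid = false;  // compiled successfully
    } catch (const std::exception& e) {
      cache_->invalid = true;   // remember the failure so we never recompile
      std::cout << "conversion failed: " << e.what() << "\n";
    }
  }
  if (cache_->invalid) {
    runWithInterpreter();  // later calls reach this directly, skipping conversion
    return;
  }
  // ... otherwise run the compiled kernel ...
}

int main() {
  run();  // attempts the conversion, fails, falls back
  run();  // skips the conversion entirely thanks to the cached `invalid` flag
}
```

The cost of a failed Relay conversion is paid once; every later call pays only the cached-flag check before falling back.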
```cpp
if (strict_) {
AT_ERROR(
"Pytorch TVM: fail to convert to relay, exception: ", e.what());
}

LOG(WARNING)
<< "Pytorch TVM: fail to convert to relay, falling back to JIT for execution, exception: "
<< "Pytorch TVM: fail to convert to relay, exception: "
<< e.what() << "\n";
}

if ((*cache_).invalid)
```
> **Contributor:** can this block be merged back into block on line 260?
```cpp
{
LOG(WARNING) << "Falling back to JIT";
InterpreterState(Code(subgraph_)).run(stack);
return;
}

auto build_f = build_mod_.GetFunction("build", false);
auto json_f = build_mod_.GetFunction("get_graph_json", false);
auto mod_f = build_mod_.GetFunction("get_module", false);
```
@@ -255,38 +277,41 @@ void TVMCompiler::run(Stack& stack) {

```cpp
tvm::runtime::Module mod = mod_f();
auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
AT_ASSERT(pfr);

tvm::runtime::Module run_mod =
(*pfr)(json, mod, (int)ctx_.device_type, (int)ctx_.device_id);
cache_[spec].set_input = run_mod.GetFunction("set_input_zero_copy", false);
cache_[spec].kernel = run_mod.GetFunction("run", false);
cache_[spec].get_output = run_mod.GetFunction("get_output", false);
(*cache_).set_input = run_mod.GetFunction("set_input_zero_copy", false);
```
> **Contributor:** …
>
> **Author:** arrgh, thanks! I forgot to switch it.
```cpp
(*cache_).kernel = run_mod.GetFunction("run", false);
(*cache_).get_output = run_mod.GetFunction("get_output", false);
auto get_num_outputs = run_mod.GetFunction("get_num_outputs", false);
int n = get_num_outputs();
AT_CHECK(
subgraph_->outputs().size() == n,
"Compiled subgraph with mismatching num outputs");
}

for (auto i = 0; i < cache_[spec].input_values.size(); ++i) {
auto* value = cache_[spec].input_values[i];
// setting arguments
for (auto i = 0; i < (*cache_).input_values.size(); ++i) {
auto* value = (*cache_).input_values[i];
if (!value_to_ivalue.count(value)) {
auto optional_ivalue = toIValue(value);
AT_ASSERT(optional_ivalue.has_value());
value_to_ivalue[value] = optional_ivalue.value();
}
auto ivalue = value_to_ivalue.at(cache_[spec].input_values[i]);
auto ivalue = value_to_ivalue.at((*cache_).input_values[i]);
auto tensor = ivalue.toTensor().to(at::kFloat);
auto dl_tensor = at::toDLPack(tensor);
cache_[spec].set_input(i, tvm::runtime::NDArray::FromDLPack(dl_tensor));
(*cache_).set_input(i, tvm::runtime::NDArray::FromDLPack(dl_tensor));
}

cache_[spec].kernel();
(*cache_).kernel();

// clean the stack and add outputs to the stack
drop(stack, num_inputs);
int i = 0;
for (const auto& output : subgraph_->outputs()) {
tvm::runtime::NDArray ret_val = cache_[spec].get_output(i);
tvm::runtime::NDArray ret_val = (*cache_).get_output(i);
auto dl_tensor = ret_val.ToDLPack();
auto tensor = at::fromDLPack(dl_tensor);
auto var = torch::autograd::make_variable(tensor);
```
@@ -5,6 +5,7 @@

```cpp
#include <torch/csrc/jit/pass_manager.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/jit/jit_log.h>

#include "compiler.h"
#include "operators.h"
```
@@ -35,6 +36,7 @@ PYBIND11_MODULE(_torch_tvm, m) {

```cpp
RegisterOperators op({Operator(
tvm_sym,
[](const Node* node) {
GRAPH_DUMP("A graph passed to TVMCompiler\n", node->g(attr::Subgraph));
```
> **Contributor:** can we move this to a different diff?
>
> **Author:** sure
```cpp
auto cc = std::make_shared<TVMCompiler>(
node, opt_level, strict, device_type, device, host);
return [cc](Stack& stack) {
```
> should we handle the case when this isn't true?

> I think we shouldn't just create a TVMCompGroup in this case. Even though we can compile at runtime, if we get a workload where shapes change all the time it will be pretty wasteful. I'm looking into changing the Fuser to not fuse things that don't have ProfiledTensorTypes.

> if it's alternating between 20 batch-size shapes but run 1M times, it's probably worth still compiling each shape.

> hopefully, this will be handled via bailouts. We specialize for shape_set1. Then if we see another frequent set, we will specialize for that one as well, and so on.
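The guard being discussed, only creating a TVM compilation group when every subgraph input carries concrete profiled sizes, could look roughly like the sketch below. `ProfiledInput` and `allInputsProfiled` are hypothetical names for illustration, not the JIT's real types:

```cpp
// Sketch of the "don't fuse unless everything was profiled" check; names are illustrative.
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

struct ProfiledInput {
  // Empty when profiling never recorded a concrete shape for this input.
  std::optional<std::vector<std::int64_t>> concrete_sizes;
};

bool allInputsProfiled(const std::vector<ProfiledInput>& inputs) {
  for (const auto& in : inputs) {
    if (!in.concrete_sizes) {
      return false;  // at least one input was never profiled: skip the TVM group
    }
  }
  return true;
}

int main() {
  std::vector<ProfiledInput> profiled = {
      {std::vector<std::int64_t>{4, 16}}, {std::vector<std::int64_t>{16, 32}}};
  std::vector<ProfiledInput> partly_profiled = {
      {std::vector<std::int64_t>{4, 16}}, {std::nullopt}};
  std::cout << std::boolalpha
            << allInputsProfiled(profiled) << "\n"          // true: safe to compile a TVM group
            << allInputsProfiled(partly_profiled) << "\n";  // false: leave it to the interpreter
}
```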