diff --git a/torch_tvm/compiler.cpp b/torch_tvm/compiler.cpp
index 63b3271..98a6ea5 100644
--- a/torch_tvm/compiler.cpp
+++ b/torch_tvm/compiler.cpp
@@ -10,24 +10,36 @@ using namespace torch::jit;
 
 tvm::relay::Var TVMCompiler::convertToRelay(Value* val, TVMContext ctx) {
   auto optional_ivalue = toIValue(val);
-  if (optional_ivalue.has_value()) {
+
+  tvm::Array<HalideIR::Expr> sizes;
+  if (auto ptt = val->type()->cast<ProfiledTensorType>())
+  {
+
+    auto csizes = ptt->sizes().concrete_sizes();
+    TORCH_INTERNAL_ASSERT(csizes.has_value());
+    for (const auto& size : *csizes)
+    {
+      sizes.push_back(HalideIR::Expr(static_cast<int32_t>(size)));
+    }
+  } else if (optional_ivalue.has_value()) {
+    // TODO: inferTypeFrom should eventually create ProfiledTensorTypes
     val->inferTypeFrom(optional_ivalue.value().toTensor());
-  }
-  if (val->isCompleteTensor()) {
-    auto pt_t = val->type()->cast<CompleteTensorType>();
-    tvm::Array<HalideIR::Expr> sizes;
+    auto pt_t = val->type()->expect<CompleteTensorType>();
     for (const auto& size : pt_t->sizes()) {
       sizes.push_back(HalideIR::Expr(static_cast<int32_t>(size)));
     }
-    // TODO: support non-float tensors
-    auto t = tvm::relay::TensorTypeNode::make(sizes, ::tvm::Float(32));
-    auto v = tvm::relay::VarNode::make(
-        val->debugName() +
-            std::to_string(reinterpret_cast<std::uintptr_t>(val)),
-        t);
-    return v;
   }
-  AT_ASSERT(0);
+  else {
+    TORCH_INTERNAL_ASSERT(0);
+  }
+
+  // TODO: support non-float tensors
+  auto t = tvm::relay::TensorTypeNode::make(sizes, ::tvm::Float(32));
+  auto v = tvm::relay::VarNode::make(
+      val->debugName() +
+          std::to_string(reinterpret_cast<std::uintptr_t>(val)),
+      t);
+  return v;
 }
 
 tvm::relay::Expr TVMCompiler::convertToRelay(
@@ -95,9 +107,8 @@ tvm::relay::Function TVMCompiler::convertToRelay(
     std::vector<Value*>* input_values) {
   std::unordered_map<Value*, tvm::relay::Expr> value_map;
   tvm::Array<tvm::relay::Var> input_vars;
-
   for (const auto& input : subgraph->inputs()) {
-    AT_ASSERT(input->isCompleteTensor());
+    TORCH_INTERNAL_ASSERT(input->type()->cast<ProfiledTensorType>());
     auto v = convertToRelay(input, ctx);
     input_vars.push_back(v);
     if (input_values) {
@@ -222,29 +233,40 @@ void TVMCompiler::run(Stack& stack) {
     value_to_ivalue[value_input] = inputs[i];
   }
 
-  CompleteArgumentSpec spec{false, ArrayRef<IValue>(inputs)};
-
-  if (cache_.find(spec) == cache_.end()) {
+  if (!cache_ || (*cache_).invalid) {
     for (auto& kv : value_to_ivalue) {
-      kv.first->inferTypeFrom(kv.second.toTensor());
+      // TODO: convince the Fuser NOT to create TVMCompilationGroups
+      // if ANY of the subgraph inputs weren't profiled
+      TORCH_INTERNAL_ASSERT(kv.first->type()->cast<ProfiledTensorType>());
     }
     // bail-out mechanism: try to convert to Relay; if the conversion fails
     // for any reason (e.g. an op difference), then depending on the user
     // preference, either throw or fall back to the JIT interpreter for execution
+    cache_ = TVMObject{};
     tvm::relay::Function tvm_func;
     try {
-      tvm_func = convertToRelay(subgraph_, ctx_, &cache_[spec].input_values);
+      tvm_func = convertToRelay(subgraph_, ctx_, &(*cache_).input_values);
+      // we compiled the subgraph successfully
+      (*cache_).invalid = false;
     } catch (const std::exception& e) {
+      (*cache_).invalid = true;
       if (strict_) {
         AT_ERROR(
             "Pytorch TVM: fail to convert to relay, exception: ", e.what());
       }
+
       LOG(WARNING)
-          << "Pytorch TVM: fail to convert to relay, falling back to JIT for execution, exception: "
+          << "Pytorch TVM: fail to convert to relay, exception: "
          << e.what() << "\n";
+    }
+
+    if ((*cache_).invalid)
+    {
+      LOG(WARNING) << "Falling back to JIT";
       InterpreterState(Code(subgraph_)).run(stack);
       return;
     }
+
     auto build_f = build_mod_.GetFunction("build", false);
     auto json_f = build_mod_.GetFunction("get_graph_json", false);
     auto mod_f = build_mod_.GetFunction("get_module", false);
@@ -255,38 +277,41 @@ void TVMCompiler::run(Stack& stack) {
     tvm::runtime::Module mod = mod_f();
     auto pfr = tvm::runtime::Registry::Get("tvm.graph_runtime.create");
     AT_ASSERT(pfr);
+
     tvm::runtime::Module run_mod =
         (*pfr)(json, mod, (int)ctx_.device_type, (int)ctx_.device_id);
-    cache_[spec].set_input = run_mod.GetFunction("set_input_zero_copy", false);
-    cache_[spec].kernel = run_mod.GetFunction("run", false);
-    cache_[spec].get_output = run_mod.GetFunction("get_output", false);
+    (*cache_).set_input = run_mod.GetFunction("set_input_zero_copy", false);
+    (*cache_).kernel = run_mod.GetFunction("run", false);
+    (*cache_).get_output = run_mod.GetFunction("get_output", false);
     auto get_num_outputs = run_mod.GetFunction("get_num_outputs", false);
     int n = get_num_outputs();
     AT_CHECK(
         subgraph_->outputs().size() == n,
         "Compiled subgraph with mismatching num outputs");
+  }
 
-  for (auto i = 0; i < cache_[spec].input_values.size(); ++i) {
-    auto* value = cache_[spec].input_values[i];
+  // setting arguments
+  for (auto i = 0; i < (*cache_).input_values.size(); ++i) {
+    auto* value = (*cache_).input_values[i];
     if (!value_to_ivalue.count(value)) {
       auto optional_ivalue = toIValue(value);
       AT_ASSERT(optional_ivalue.has_value());
       value_to_ivalue[value] = optional_ivalue.value();
     }
-    auto ivalue = value_to_ivalue.at(cache_[spec].input_values[i]);
+    auto ivalue = value_to_ivalue.at((*cache_).input_values[i]);
     auto tensor = ivalue.toTensor().to(at::kFloat);
     auto dl_tensor = at::toDLPack(tensor);
-    cache_[spec].set_input(i, tvm::runtime::NDArray::FromDLPack(dl_tensor));
+    (*cache_).set_input(i, tvm::runtime::NDArray::FromDLPack(dl_tensor));
   }
-  cache_[spec].kernel();
+  (*cache_).kernel();
 
   // clean the stack and add outputs to the stack
   drop(stack, num_inputs);
   int i = 0;
   for (const auto& output : subgraph_->outputs()) {
-    tvm::runtime::NDArray ret_val = cache_[spec].get_output(i);
+    tvm::runtime::NDArray ret_val = (*cache_).get_output(i);
     auto dl_tensor = ret_val.ToDLPack();
     auto tensor = at::fromDLPack(dl_tensor);
     auto var = torch::autograd::make_variable(tensor);
diff --git a/torch_tvm/compiler.h b/torch_tvm/compiler.h
index efdfb2f..1f63edf 100644
--- a/torch_tvm/compiler.h
+++ b/torch_tvm/compiler.h
@@ -15,6 +15,7 @@ struct TVMObject {
   tvm::PackedFunc get_output;
   // Map input indices to values in the subgraph
   std::vector<Value*> input_values;
+  bool invalid = true;
 };
 
 struct TVMCompiler {
@@ -29,7 +30,7 @@ struct TVMCompiler {
 
  private:
   std::shared_ptr<Graph> subgraph_;
-  std::unordered_map<CompleteArgumentSpec, TVMObject> cache_;
+  c10::optional<TVMObject> cache_;
   TVMContext ctx_;
   int opt_level_;
   bool strict_;
diff --git a/torch_tvm/register.cpp b/torch_tvm/register.cpp
index a483359..c8a969e 100644
--- a/torch_tvm/register.cpp
+++ b/torch_tvm/register.cpp
@@ -5,6 +5,7 @@
 #include
 #include
 #include
+#include <torch/csrc/jit/jit_log.h>
 #include "compiler.h"
 #include "operators.h"
@@ -35,6 +36,7 @@ PYBIND11_MODULE(_torch_tvm, m) {
   RegisterOperators op({Operator(
       tvm_sym,
       [](const Node* node) {
+        GRAPH_DUMP("A graph passed to TVMCompiler\n", node->g(attr::Subgraph));
         auto cc = std::make_shared<TVMCompiler>(
             node, opt_level, strict, device_type, device, host);
         return [cc](Stack& stack) {
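
Note (not part of the patch): the change above replaces the per-CompleteArgumentSpec cache with a single c10::optional<TVMObject> entry whose `invalid` flag drives the bail-out, relying on the profiler to guarantee stable input shapes. Below is a minimal standalone sketch (C++17) of that control flow; `CompiledArtifact`, `compileSubgraph`, and the free function `run` are hypothetical stand-ins for `TVMObject`, `convertToRelay`, and `TVMCompiler::run`.

#include <iostream>
#include <optional>
#include <stdexcept>

// Stand-in for TVMObject: one cached compilation attempt, pessimistically
// marked invalid until compilation succeeds.
struct CompiledArtifact {
  bool invalid = true;
};

// A single entry replaces the per-argument-spec map.
static std::optional<CompiledArtifact> cache;

CompiledArtifact compileSubgraph(bool fail) {
  if (fail) {
    // Mirrors convertToRelay throwing on, e.g., an unsupported op.
    throw std::runtime_error("op difference");
  }
  return CompiledArtifact{};
}

void run(bool fail) {
  // Compile only on the first call or after a failed attempt.
  if (!cache || cache->invalid) {
    cache = CompiledArtifact{};
    try {
      *cache = compileSubgraph(fail);
      cache->invalid = false;  // compiled successfully; reuse on later calls
    } catch (const std::exception&) {
      cache->invalid = true;  // remember the failure for the bail-out below
    }
  }
  if (cache->invalid) {
    std::cout << "falling back to the JIT interpreter\n";
    return;
  }
  std::cout << "running the compiled kernel\n";
}

int main() {
  run(true);   // conversion fails: entry stays invalid, falls back
  run(false);  // invalid entry: conversion is retried, now succeeds
  run(false);  // cache hit: runs without recompiling
}

Two behavioral consequences follow from the sketch: a failed conversion is retried on every invocation (the entry stays invalid, so the compile block is re-entered), and the old code's default-insertion of `cache_[spec]` before a failed conversion can no longer leave a half-initialized entry to be found, and run, on the next call.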