Skip to content
This repository was archived by the owner on Apr 1, 2021. It is now read-only.

Commit 313fb8b

Browse files
kimishpatel authored and bwasti committed
Fixes for the changes in PT API. (#92)
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent faecd9d commit 313fb8b

File tree

2 files changed

+18
-7
lines changed

2 files changed

+18
-7
lines changed

torch_tvm/compiler.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ tvm::relay::Var TVMCompiler::convertToRelay(Value* val, TVMContext ctx) {
3232
auto optional_ivalue = toIValue(val);
3333
if (optional_ivalue.has_value()) {
3434
if (optional_ivalue.value().isTensor()) {
35+
auto t = optional_ivalue.value().toTensor();
3536
val->inferTypeFrom(optional_ivalue.value().toTensor());
3637
} else {
3738
auto expr = convertToRelay(optional_ivalue.value(), ctx)
@@ -45,15 +46,23 @@ tvm::relay::Var TVMCompiler::convertToRelay(Value* val, TVMContext ctx) {
4546
if (val->isCompleteTensor()) {
4647
// Ensure if complete tensor has device type then it is CPU
4748
// otherwise it is assume to be CPU.
48-
auto pt_t = val->type()->cast<CompleteTensorType>();
49-
auto device_type = pt_t->device();
49+
auto pt_t = val->type()->cast<ProfiledTensorType>();
50+
TORCH_INTERNAL_ASSERT(pt_t);
51+
auto optional_device_type = pt_t->device();
52+
TORCH_INTERNAL_ASSERT(optional_device_type);
53+
auto device_type = optional_device_type.value();
5054
AT_CHECK(device_type == at::DeviceType::CPU,
5155
"Expected CPU device type but got:", device_type);
5256
tvm::Array<tvm::relay::IndexExpr> sizes;
53-
for (const auto& size : pt_t->sizes()) {
54-
sizes.push_back(tvm::relay::IndexExpr(static_cast<int32_t>(size)));
57+
const auto& varying_sizes = pt_t->sizes();
58+
for (const auto& optional_size : varying_sizes.sizes()) {
59+
TORCH_INTERNAL_ASSERT(optional_size);
60+
sizes.push_back(tvm::relay::IndexExpr(
61+
static_cast<int32_t>(optional_size.value())));
5562
}
56-
at::ScalarType pt_type = pt_t->scalarType();
63+
auto optional_dtype = pt_t->scalarType();
64+
TORCH_INTERNAL_ASSERT(optional_dtype);
65+
at::ScalarType pt_type = optional_dtype.value();
5766
auto t = tvm::relay::TensorTypeNode::make(sizes, scalarTypeToTVMType(pt_type));
5867
auto v = tvm::relay::VarNode::make(
5968
val->debugName() +

torch_tvm/operators.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,9 +506,11 @@ RegisterTVMOperator reg({
506506
{Symbol::fromQualString("aten::linear"),
507507
[](Node* node, tvm::Array<tvm::relay::Expr> inputs) {
508508
Value* input = node->input(0);
509-
auto d_tensor = input->type()->cast<DimensionedTensorType>();
509+
auto d_tensor = input->type()->cast<ProfiledTensorType>();
510510
if (d_tensor) {
511-
int64_t n_dim = d_tensor->dim();
511+
auto optional_n_dim = d_tensor->dim();
512+
TORCH_INTERNAL_ASSERT(optional_n_dim);
513+
int64_t n_dim = optional_n_dim.value();
512514
TORCH_CHECK(n_dim == 2,
513515
"WARNING: relay does not support dense operation on inputs more than 2 dim");
514516
}

0 commit comments

Comments (0)