Skip to content
This repository was archived by the owner on Apr 1, 2021. It is now read-only.

Commit 6b8674e

Browse files
authored
Added support for dropout in TVM which essentially is just skipping it. (#87)
Summary: Test Plan: Reviewers: Subscribers: Tasks: Tags:
1 parent 313fb8b commit 6b8674e

File tree

1 file changed

+9
-0
lines changed

1 file changed

+9
-0
lines changed

torch_tvm/operators.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,6 +541,15 @@ RegisterTVMOperator reg({
541541
{ inputs[0] },
542542
tvm::Attrs(softmax_attrs));
543543
}},
544+
{Symbol::fromQualString("aten::dropout"),
545+
[](Node* node, tvm::Array<tvm::relay::Expr> inputs) {
546+
TORCH_CHECK(inputs.size() == 3, "Expected number of inputs 3, got ",
547+
inputs.size());
548+
auto train = relayToConstant<bool>(inputs[2]);
549+
TORCH_CHECK(!train, "Only inference mode dropout is supported"
550+
" in torch tvm");
551+
return inputs[0];
552+
}},
544553
});
545554

546555
bool isSupported(Node* node) {

0 commit comments

Comments
 (0)