This repository was archived by the owner on Apr 1, 2021. It is now read-only.

Commit ea2cdd9

Map batch matmul from PT to TVM.
Also added test.
1 parent 313fb8b commit ea2cdd9

File tree

2 files changed: +34 -0 lines changed


test/test_operators.py

Lines changed: 14 additions & 0 deletions

@@ -326,6 +326,20 @@ def linear_no_bias(input, weight):
         ref_out_no_bias, tvm_out_no_bias = self.runBoth(linear_no_bias, input, weight)
         assert torch.allclose(ref_out_no_bias, tvm_out_no_bias, rtol=0.01, atol=0.01)
 
+    @TVMTest.given(
+        shape=TVMTest.rand_shape(rank=3, min_dim=4),
+        out_features=TVMTest.rand_int(3, 6),
+    )
+    def test_bmm(self, shape, out_features):
+        input = torch.rand(shape)
+        weight = torch.rand(shape[0], shape[-1], out_features)
+
+        def bmm(input, weight):
+            return torch.bmm(input, weight) + 2.0
+
+        ref_out, tvm_out = self.runBoth(bmm, input, weight)
+        assert torch.allclose(ref_out, tvm_out, rtol=0.01, atol=0.01)
+
     @TVMTest.given(
         shape=TVMTest.rand_shape(rank=2, min_dim=4),
     )
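For orientation, here is a minimal stand-alone sketch (not part of the commit) of the arithmetic the new test exercises: input is a rank-3 tensor of shape (b, n, k), the weight is built as (b, k, out_features), and the TVM-lowered graph is expected to match eager torch.bmm within the stated tolerances. The concrete shape below is an illustrative choice; the repository's runBoth harness, which runs the function both eagerly and through TVM, is not reproduced here.

    import torch

    # Illustrative values; the test draws them via TVMTest.rand_shape (rank=3,
    # min_dim=4) and TVMTest.rand_int(3, 6).
    shape = (4, 5, 6)
    out_features = 3

    input = torch.rand(shape)                               # (b, n, k)
    weight = torch.rand(shape[0], shape[-1], out_features)  # (b, k, out_features)

    def bmm(input, weight):
        return torch.bmm(input, weight) + 2.0               # (b, n, out_features)

    # The test compares this eager result against the TVM path via self.runBoth;
    # only the eager side is shown here.
    ref_out = bmm(input, weight)
    assert ref_out.shape == (shape[0], shape[1], out_features)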

torch_tvm/operators.cpp

Lines changed: 20 additions & 0 deletions

@@ -503,6 +503,26 @@ RegisterTVMOperator reg({
           tvm::relay::CallNode::make(op, {inputs[0]}, tvm::Attrs(attrs), {});
       return out;
     }},
+    {Symbol::fromQualString("aten::bmm"),
+     [](Node* node, tvm::Array<tvm::relay::Expr> inputs) {
+       TORCH_INTERNAL_ASSERT(inputs.size()==2);
+
+       auto transpose_attrs = tvm::make_node<tvm::relay::TransposeAttrs>();
+       auto& axes = transpose_attrs->axes;
+       axes.push_back(0);
+       axes.push_back(2);
+       axes.push_back(1);
+
+       auto transposed_weight = tvm::relay::CallNode::make(
+           tvm::relay::Op::Get("transpose"),
+           {inputs[1]}, tvm::Attrs(transpose_attrs), {});
+
+       auto out = tvm::relay::CallNode::make(
+           tvm::relay::Op::Get("nn.batch_matmul"),
+           {inputs[0], transposed_weight}, {}, {});
+
+       return out;
+     }},
     {Symbol::fromQualString("aten::linear"),
      [](Node* node, tvm::Array<tvm::relay::Expr> inputs) {
        Value* input = node->input(0);
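Why the converter transposes the weight first: torch.bmm multiplies (b, n, k) by (b, k, m), whereas Relay's nn.batch_matmul treats its second operand as (b, m, k) and contracts over the last axis, i.e. it computes A @ B^T per batch. Permuting the weight with axes (0, 2, 1) makes the two conventions agree. Below is a minimal PyTorch sketch of that equivalence; the helper name is ours, not part of the commit.

    import torch

    def relay_style_batch_matmul(x, y):
        # Emulates Relay nn.batch_matmul: x is (b, n, k), y is (b, m, k),
        # and the result is (b, n, m) = x @ y^T per batch.
        return torch.matmul(x, y.transpose(1, 2))

    b, n, k, m = 2, 4, 5, 3
    A = torch.rand(b, n, k)
    B = torch.rand(b, k, m)

    ref = torch.bmm(A, B)
    # The converter applies transpose(axes=(0, 2, 1)) to the weight before the
    # call, so the two transposes cancel and the result matches torch.bmm.
    out = relay_style_batch_matmul(A, B.permute(0, 2, 1))
    assert torch.allclose(ref, out)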
