
Commit

[TensorExpr] Port NNC lowerings to the new registry mechanism. (pytorch#65551)

Summary:
Pull Request resolved: pytorch#65551

Previously we had a big switch on op kind to decide how to lower a given
JIT operator to NNC. This PR changes that switch to a hash table lookup.

Why? This helps us with at least two things:
1) With this approach we can easily check in advance whether we know how
to handle a given node - i.e. we can inspect the entire graph and tell
whether it is possible to compile it, without actually trying to do so
and failing partway through (see the sketch after this list). This would
allow us to, say, provide user-friendly error messages in the AOT
workflow.
2) We can switch to using the op schema instead of the op kind to
determine the correct lowering. Unlike the schema, the op kind can be
ambiguous (see e.g. pytorch#64963), and using it instead of the schema
can lead to bugs.
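
For illustration, here is a minimal sketch of the registry idea in plain
C++. This is not the actual NNC API: Node, Tensor, LoweringFunc,
loweringRegistry, registerLowering, and hasLoweringFor are hypothetical
stand-ins. The point is that lowerings stored in a hash table keyed by
the full operator schema can be queried up front, which enables the
whole-graph "can we compile this?" check and disambiguates overloads
that share an op kind.

// A minimal sketch of the registry idea (hypothetical names, not the
// actual NNC API). Lowerings live in a hash table keyed by the full
// operator schema instead of being dispatched via a switch on op kind.
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

struct Node {};   // stand-in for a JIT node
struct Tensor {}; // stand-in for an NNC tensor

using LoweringFunc = std::function<Tensor(const Node&)>;

std::unordered_map<std::string, LoweringFunc>& loweringRegistry() {
  // Keyed by schema string, so overloads that share an op kind
  // (e.g. two aten::sum overloads) get distinct entries.
  static std::unordered_map<std::string, LoweringFunc> registry;
  return registry;
}

void registerLowering(const std::string& schema, LoweringFunc f) {
  loweringRegistry()[schema] = std::move(f);
}

// Benefit 1: a whole graph can be checked up front, without starting a
// lowering that fails halfway through.
bool hasLoweringFor(const std::string& schema) {
  return loweringRegistry().count(schema) != 0;
}

int main() {
  registerLowering(
      "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor",
      [](const Node&) { return Tensor{}; });
  std::cout << hasLoweringFor(
                   "aten::sum(Tensor self, *, ScalarType? dtype=None) -> Tensor")
            << "\n"; // prints 1
  std::cout << hasLoweringFor("aten::relu(Tensor self) -> Tensor")
            << "\n"; // prints 0
}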

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D31148926

Pulled By: ZolotukhinM

fbshipit-source-id: ac12684e2126c899426ef5e4cc1e3f70fa01f704
Mikhail Zolotukhin authored and facebook-github-bot committed Oct 1, 2021
1 parent eee9ad0 commit 3a0165d
Showing 14 changed files with 874 additions and 427 deletions.
8 changes: 5 additions & 3 deletions benchmarks/cpp/tensorexpr/bench_reduce.cpp
@@ -385,7 +385,8 @@ BENCHMARK_DEFINE_F(Reduce1D, Op)(benchmark::State& state) {
   const int kChunkSize = 8;

   te::BufHandle a("A", {M}, te::kFloat);
-  te::Tensor b = te::computeSum({a, te::IntList({0}), false}, at::kFloat);
+  te::Tensor b =
+      te::computeSum({a, te::IntList({0}), false}, {}, at::kFloat, at::kCPU);
   te::LoopNest nest({b});

   auto loops = nest.getLoopStmtsFor(b);
@@ -447,7 +448,8 @@ BENCHMARK_REGISTER_F(Reduce2DCol, Torch)
 BENCHMARK_DEFINE_F(Reduce2DCol, OpSchedule)(benchmark::State& state) {
   constexpr int kCacheSize = 1 << 12;
   te::BufHandle a("A", {M, N}, te::kFloat);
-  te::Tensor b = te::computeSum({a, te::IntList({0}), false}, at::kFloat);
+  te::Tensor b =
+      te::computeSum({a, te::IntList({0}), false}, {N}, at::kFloat, at::kCPU);
   te::LoopNest nest({b});

   auto sch = state.range(2);
@@ -553,7 +555,7 @@ BENCHMARK_REGISTER_F(Reduce2DRow, Hand)
 BENCHMARK_DEFINE_F(Reduce2DRow, OpSchedule)(benchmark::State& state) {
   constexpr int kChunkSize = 8;
   te::BufHandle a("A", {M, N}, te::kFloat);
-  te::Tensor b = te::computeSum({a, te::IntList({1}), false}, at::kFloat);
+  te::Tensor b = te::computeSum({a, te::IntList({1}), false}, {M}, at::kFloat, at::kCPU);
   te::LoopNest nest({b});

   auto sch = state.range(2);
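Across the three hunks above, the computeSum calls gain two arguments:
an explicit output shape ({} for the full 1-D reduction, {N} for the
column reduction over dim 0, {M} for the row reduction over dim 1 of an
{M, N} input) and a target device (at::kCPU). Below is a self-contained
sketch of the shape rule those callers now encode by hand; it is plain
C++ with a hypothetical reducedShape helper, not NNC code.

// Sketch (not NNC code): the output shape that callers now pass
// explicitly, e.g. {M, N} reduced over {0} -> {N}, over {1} -> {M},
// and over {0, 1} -> {} (a scalar).
#include <cstdint>
#include <iostream>
#include <set>
#include <vector>

std::vector<int64_t> reducedShape(
    const std::vector<int64_t>& inShape,
    const std::vector<int64_t>& dims,
    bool keepdim) {
  std::set<int64_t> reduced(dims.begin(), dims.end());
  std::vector<int64_t> out;
  for (int64_t i = 0; i < (int64_t)inShape.size(); i++) {
    if (!reduced.count(i)) {
      out.push_back(inShape[i]); // dimension survives the reduction
    } else if (keepdim) {
      out.push_back(1); // reduced dimension kept as size 1
    }
  }
  return out;
}

int main() {
  for (auto d : reducedShape({8, 16}, {0}, false)) {
    std::cout << d << " "; // prints 16
  }
  std::cout << "\n";
}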
1 change: 1 addition & 0 deletions test/cpp/tensorexpr/CMakeLists.txt
@@ -14,6 +14,7 @@ set(TENSOREXPR_TEST_SRCS
   ${TENSOREXPR_TEST_ROOT}/test_kernel.cpp
   ${TENSOREXPR_TEST_ROOT}/test_loopnest.cpp
   ${TENSOREXPR_TEST_ROOT}/test_memdependency.cpp
+  ${TENSOREXPR_TEST_ROOT}/test_ops.cpp
   ${TENSOREXPR_TEST_ROOT}/test_reductions.cpp
   ${TENSOREXPR_TEST_ROOT}/test_registerizer.cpp
   ${TENSOREXPR_TEST_ROOT}/test_simplify.cpp
11 changes: 7 additions & 4 deletions test/cpp/tensorexpr/test_ops.cpp
@@ -20,13 +20,16 @@ std::unique_ptr<SimpleIREvaluator> compile(
 }

 TEST(Ops, Sum) {
-  constexpr int M = 8;
-  constexpr int N = 16;
   std::vector<IntList> testDims = {{0}, {1}, {0, 1}};
-  for (auto const& dims : testDims) {
+  constexpr int M = 8;
+  constexpr int N = 16;
+  std::vector<std::vector<ExprHandle>> outputShapes = {{N}, {M}, {}};
+  for (int idx = 0; idx < testDims.size(); idx++) {
+    const auto& dims = testDims[idx];
+    const auto& outShape = outputShapes[idx];

     BufHandle a("a", {M, N}, kFloat);
-    Tensor b = computeSum({a, dims, false}, c10::kFloat);
+    Tensor b = computeSum({a, dims, false}, outShape, c10::kFloat, at::kCPU);
     auto cg = compile({a}, {b});

     auto at = at::arange(M * N, at::kFloat).view({M, N});
[Diffs for the remaining 11 changed files not shown.]
