Skip to content

Commit

Permalink
Update onnx app to work with newer versions of protobuf (#7879)
Browse files Browse the repository at this point in the history
and to work on mac
  • Loading branch information
abadams authored Oct 6, 2023
1 parent 120e5fd commit 24a64f8
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 17 deletions.
10 changes: 6 additions & 4 deletions apps/onnx/Makefile
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Default to the adams 2019 autoscheduler
AUTOSCHEDULER ?= adams2019
include ../support/Makefile.inc

ifneq (,$(findstring -m32,$(CXX) $(CC) $(CCFLAGS) $(CXXFLAGS)))
Expand All @@ -13,7 +15,7 @@ ifdef PROTOC
CXXFLAGS += -DGOOGLE_PROTOBUF_NO_RTTI -Wno-sign-compare -Wno-unused-but-set-variable
CXXFLAGS += -I$(dir $(PROTOC))../include
LDFLAGS += -L$(dir $(PROTOC))../lib
LDFLAGS += -lprotobuf-lite
LDFLAGS += $(shell pkg-config protobuf-lite --libs)

# Copy onnx.proto to $(BIN)
$(BIN)/%/onnx/onnx.proto:
Expand Down Expand Up @@ -83,18 +85,18 @@ test: build model_test
LD_LIBRARY_PATH=$(BIN) $(BIN)/$(HL_TARGET)/onnx_converter_generator_test

PYTHON ?= python3
PYBIND11_CFLAGS = $(shell $(PYTHON)-config --includes) -frtti
PYBIND11_CFLAGS = $(shell pybind11-config --includes) -frtti -std=c++17
ifeq ($(UNAME), Darwin)
# Keep OS X builds from complaining about missing libpython symbols
PYBIND11_CFLAGS += -undefined dynamic_lookup
endif
PY_EXT = $(shell $(PYTHON)-config --extension-suffix)
PY_MODEL_EXT = model_cpp$(PY_EXT)
PYCXXFLAGS = $(CXXFLAGS) $(PYBIND11_CFLAGS) -Wno-deprecated-register
PYCXXFLAGS = $(PYBIND11_CFLAGS) $(CXXFLAGS) -Wno-deprecated-register

# Python extension for HalideModel
$(BIN)/%/$(PY_MODEL_EXT): model.cpp $(BIN)/%/oclib.a
$(CXX) $(PYCXXFLAGS) -Wall -shared -fPIC -I$(BIN)/$* $^ $(LIBHALIDE_LDFLAGS) -Wl,--no-as-needed -lautoschedule_adams2019 -Wl,--as-needed -o $@ $(LDFLAGS)
$(CXX) $(PYCXXFLAGS) -Wall -shared -fPIC -I$(BIN)/$* $^ $(LIBHALIDE_LDFLAGS) -o $@ $(LDFLAGS)


model_test: $(BIN)/$(HL_TARGET)/$(PY_MODEL_EXT)
Expand Down
26 changes: 13 additions & 13 deletions apps/onnx/onnx_converter_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

#define EXPECT_EQ(a, b) \
if ((a) != (b)) { \
exit(1); \
exit(1); \
}
#define EXPECT_NEAR(a, b, c) \
if (std::abs((a) - (b)) > (c)) { \
exit(1); \
exit(1); \
}

static void test_abs() {
Expand All @@ -30,7 +30,7 @@ static void test_abs() {

Node converted = convert_node(abs_node, node_inputs);

GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 1> output = converted.outputs[0].rep.realize({200});
for (int i = 0; i < 200; ++i) {
EXPECT_EQ(output(i), std::abs(input(i)));
Expand All @@ -56,7 +56,7 @@ static void test_activation_function() {

Node converted = convert_node(relu_node, node_inputs);

GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 1> output = converted.outputs[0].rep.realize({200});
for (int i = 0; i < 200; ++i) {
EXPECT_EQ(output(i), std::max(input(i), 0.0f));
Expand Down Expand Up @@ -85,7 +85,7 @@ static void test_cast() {

Node converted = convert_node(cast_node, node_inputs);

GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 1> output = converted.outputs[0].rep.realize({200});
for (int i = 0; i < 200; ++i) {
EXPECT_EQ(output(i), static_cast<float>(input(i)));
Expand Down Expand Up @@ -117,7 +117,7 @@ static void test_add() {

Node converted = convert_node(add_node, node_inputs);

GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 1> output = converted.outputs[0].rep.realize({200});
for (int i = 0; i < 200; ++i) {
EXPECT_NEAR(output(i), in1(i) + in2(i), 1e-6);
Expand All @@ -144,7 +144,7 @@ static void test_constant() {

Node converted = convert_node(add_node, {});

GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 2> output = converted.outputs[0].rep.realize({3, 7});
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 7; ++j) {
Expand Down Expand Up @@ -186,7 +186,7 @@ static void test_gemm() {
node_inputs[2].rep(i3, j3) = in3(i3, j3);
Node converted = convert_node(add_node, node_inputs);

GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 2> output = converted.outputs[0].rep.realize({32, 64});

for (int i = 0; i < 32; ++i) {
Expand Down Expand Up @@ -239,7 +239,7 @@ static void test_conv() {

Node converted = convert_node(add_node, node_inputs);

GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 4> output =
converted.outputs[0].rep.realize(out_shape[trial]);

Expand Down Expand Up @@ -287,7 +287,7 @@ static void test_sum() {

Node converted = convert_node(sum_node, node_inputs);

GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 4> output = converted.outputs[0].rep.realize({1, 3, 1, 11});
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 11; ++j) {
Expand Down Expand Up @@ -331,7 +331,7 @@ static void test_where_broadcast() {
node_inputs[2].rep(i, j) = in_y(i, j);

Node converted = convert_node(where_node, node_inputs);
GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 3> output = converted.outputs[0].rep.realize({2, 2, 2});

for (int i = 0; i < 2; ++i) {
Expand Down Expand Up @@ -375,7 +375,7 @@ static void test_concat() {

Node converted = convert_node(concat_node, node_inputs);

GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<float, 2> output = converted.outputs[0].rep.realize({7 + 5, 3});
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 7; ++j) {
Expand Down Expand Up @@ -405,7 +405,7 @@ static void test_constant_fill() {
dtype_attr->set_i(4);

Node converted = convert_node(concat_node, {});
GOOGLE_CHECK_EQ(1, converted.outputs.size());
EXPECT_EQ(1, converted.outputs.size());
Halide::Buffer<uint16_t, 2> output = converted.outputs[0].rep.realize({3, 4});
for (int i = 0; i < 3; ++i) {
for (int j = 0; j < 4; ++j) {
Expand Down

0 comments on commit 24a64f8

Please sign in to comment.