Skip to content

Commit

Permalink
optimize for inference, num threads control, docs
Browse files Browse the repository at this point in the history
  • Loading branch information
SiLiKhon committed Aug 27, 2020
1 parent 6ea8a53 commit adaa4e2
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 4 deletions.
10 changes: 9 additions & 1 deletion model_export/dump_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import tensorflow as tf
from tensorflow.python.framework import convert_to_constants
from tensorflow.core.framework import attr_value_pb2
from tensorflow.python.tools import optimize_for_inference_lib

from . import tf2xla_pb2

Expand Down Expand Up @@ -35,8 +36,15 @@ def to_save(x):
path = str(output_file.parent)
filename = output_file.name

tf.io.write_graph(
optimized_graph = optimize_for_inference_lib.optimize_for_inference(
constant_graph.graph.as_graph_def(),
[i.op.name for i in constant_graph.inputs],
[o.op.name for o in constant_graph.outputs],
tf.float32.as_datatype_enum
)

tf.io.write_graph(
optimized_graph,
path, filename
)

Expand Down
7 changes: 5 additions & 2 deletions model_export/model_v4/model.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@ extern "C" int get_batch_size() {
return BATCH_SIZE;
}

extern "C" void model_init() {
if (!tp) tp = new Eigen::ThreadPool(std::thread::hardware_concurrency());
extern "C" void model_init(int num_threads) {
if (num_threads < 1) {
num_threads = std::thread::hardware_concurrency();
}
if (!tp) tp = new Eigen::ThreadPool(num_threads);
if (!device) device = new Eigen::ThreadPoolDevice(tp, tp->NumThreads());
if (!graph) {
graph = new GRAPH_CLASS;
Expand Down
38 changes: 37 additions & 1 deletion model_export/model_v4/model.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,40 @@
/**
 * Get the batch size this model was compiled with.
 *
 * The batch size is fixed at compile time (BATCH_SIZE in model.cc), so
 * callers must size their buffers from this value before calling model_run().
 *
 * @return The batch size.
 */
extern "C" int get_batch_size();
// NOTE(review): the zero-argument declaration below is the pre-commit form
// shown by the diff; this commit replaces it with
// model_init(int num_threads = 0) further down. It should not appear in the
// final header — confirm against the post-commit file.
extern "C" void model_init();

/**
 * Initialize the model.
 *
 * This function instantiates the model graph, as well as
 * Eigen::ThreadPool and Eigen::ThreadPoolDevice to run the graph.
 *
 * The pool, device and graph are only created if they do not already exist
 * (guarded by null checks in model.cc), so num_threads takes effect on the
 * first call only; subsequent calls with a different value are no-ops for
 * the thread pool.
 *
 * @param num_threads Number of threads for Eigen::ThreadPool. If smaller than 1,
 *                    will deduce automatically with std::thread::hardware_concurrency().
 *                    Defaults to 0.
 *
 * Note: the default argument is a C++-only convenience — C callers linking
 * against this extern "C" symbol must pass num_threads explicitly.
 */
extern "C" void model_init(int num_threads = 0);

/**
 * Run the model.
 *
 * model_init() must have been called first; otherwise this returns -2.
 *
 * @param input Pointer to input data. This will be copied to the model's
 *              memory prior to actually running the model.
 * @param output Pointer to the output data. The result will be copied here
 *               from the model's memory after running the model.
 * @param input_size Size of the input buffer, in elements (floats).
 *                   Must be batch_size * 4.
 * @param output_size Size of the output buffer, in elements (floats).
 *                    Must be batch_size * 8 * 16.
 *
 * @return Status.
 *         -2 = model was not initialized
 *         -1 = running the graph failed
 *          0 = success
 */
extern "C" int model_run(float *input, float *output, int input_size, int output_size);

/**
 * De-initialize the model.
 *
 * Presumably releases the graph, thread pool and device allocated by
 * model_init() — body not visible in this hunk, confirm in model.cc.
 * After calling this, model_run() is expected to return -2 until
 * model_init() is called again.
 */
extern "C" void model_free();

0 comments on commit adaa4e2

Please sign in to comment.