diff --git a/docs/debugging-optimizing/example-log.md b/docs/debugging-optimizing/example-log.md index 707daaa52f..4ea303865f 100644 --- a/docs/debugging-optimizing/example-log.md +++ b/docs/debugging-optimizing/example-log.md @@ -16,7 +16,7 @@ limitations under the License. # Example Log, Verbose Level = diagnose -The following is an example log with `NM_LOGGING_LEVEL=diagnose` running a super_resolution network, where we only support running 70% of it. Different portions of the log are explained in [Parsing an Example Log](./diagnostics-debugging.md#parsing-an-example-log). +The following is an example log with `NM_LOGGING_LEVEL=diagnose` running a super_resolution network, where we only support running 70% of it. Different portions of the log are explained in [Parsing an Example Log](diagnostics-debugging.md#parsing-an-example-log). ```bash onnx_filename : test-models/cv-resolution/super_resolution/none-bsd300-onnx-repo/model.onnx diff --git a/docs/index.rst b/docs/index.rst index ade7ccf00d..b5eac02105 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -66,6 +66,7 @@ For example, pruning plus quantization can give noticeable improvements in perfo The Deep Sparse product suite builds on top of sparsification enabling you to easily apply the techniques to your datasets and models using recipe-driven approaches. Recipes encode the directions for how to sparsify a model into a simple, easily editable format. + - Download a sparsification recipe and sparsified model from the `SparseZoo `_. - Alternatively, create a recipe for your model using `Sparsify `_. - Apply your recipe with only a few lines of code using `SparseML `_. @@ -121,6 +122,7 @@ Additionally, more information can be found via :caption: Performance debugging-optimizing/index + source/scheduler .. toctree:: :maxdepth: 2 diff --git a/docs/source/multi-stream.png b/docs/source/multi-stream.png new file mode 100644 index 0000000000..769133fb0d Binary files /dev/null and b/docs/source/multi-stream.png differ diff --git a/docs/source/scheduler.md b/docs/source/scheduler.md new file mode 100644 index 0000000000..209fb1d5f0 --- /dev/null +++ b/docs/source/scheduler.md @@ -0,0 +1,55 @@ + + +## Serial or Concurrent Inferences +Schedulers are special pieces of system software that handle the distribution of work across cores in parallel computation. The goal of a good scheduler is to ensure that while work is available, cores aren’t sitting idle. Put another way, as long as parallel tasks are available, all cores should be kept busy. + +In most use cases, the default scheduler is the preferred choice when running inferences with the DeepSparse Engine. It's highly optimized for minimum per-request latency, using all of the system's resources provided to it on every request it gets. Often, particularly when working with large batch sizes, the scheduler is able to distribute the workload of a single request across as many cores as it's provided. + +![Single-stream scheduling diagram](single-stream.png) + +_Single stream scheduling; requests execute serially by default_ + +However, there are circumstances in which more cores do not imply better performance. If the computation can't be divided up to produce enough parallelism (while maximizing use of the CPU cache), then adding more cores simply adds more compute power with little to apply it to. + +An alternative, "multi-stream" scheduler is provided with the software. In cases where parallelism is low, sending multiple requests simultaneously can more adequately saturate the available cores.
In other words, if speedup can't be achieved by adding more cores, then perhaps speedup can be achieved by adding more work. + +If increasing core count doesn't decrease latency, that's a strong indicator that parallelism is low in your particular model/batch-size combination. It may be that total throughput can be increased by making more requests simultaneously. Using the [deepsparse.engine.Scheduler API](https://docs.neuralmagic.com/deepsparse/api/deepsparse.html), the multi-stream scheduler can be selected, and requests made by multiple Python threads will be handled concurrently. + +![Multi-stream scheduling diagram](multi-stream.png) + +_Multi-stream scheduling; requests execute in parallel and may utilize hardware resources better_ + +Whereas the default scheduler will queue up requests made simultaneously and handle them serially, the multi-stream scheduler maintains a set of dropboxes where requests may be deposited and the requesting threads can wait. These dropboxes allow workers to find work from multiple sources when work from a single source would otherwise be scarce, maximizing throughput. When a request is complete, the requesting thread is awakened and returns the results to the caller. + +The most common use cases for the multi-stream scheduler are where parallelism is low with respect to core count, and where requests need to be made asynchronously without time to batch them. Implementing a model server may fit such a scenario and be ideal for using multi-stream scheduling. + +Depending on your engine execution strategy, enable one of these options by running: + +```python +engine = compile_model(model_path, batch_size, num_cores, num_sockets, "single_stream") +``` + +or + +```python +engine = compile_model(model_path, batch_size, num_cores, num_sockets, "multi_stream") +``` + +or pass in the enum value directly, since `"multi_stream" == Scheduler.multi_stream`. + +By default, the scheduler will map to a single stream.
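To make the concurrent pattern above concrete, here is a minimal sketch of submitting requests from several Python threads to an engine compiled with the multi-stream scheduler. The model path, thread count, and request count are placeholders; `compile_model`, `generate_random_inputs`, and `timed_run` are used only in the forms that appear elsewhere in this changeset.

```python
from concurrent.futures import ThreadPoolExecutor

from deepsparse import compile_model
from deepsparse.utils import generate_random_inputs

model_path = "model.onnx"  # placeholder path to an ONNX model
batch_size = 1

# Compile with the multi-stream scheduler; None lets the engine pick
# cores/sockets, mirroring the positional form shown above.
engine = compile_model(model_path, batch_size, None, None, "multi_stream")
inputs = generate_random_inputs(model_path, batch_size)


def infer(_):
    # Each thread issues its own request; with the multi-stream scheduler
    # these are serviced concurrently rather than queued serially.
    outputs, elapsed_time = engine.timed_run(inputs)
    return elapsed_time


# Thread and request counts here are arbitrary, for illustration only.
with ThreadPoolExecutor(max_workers=4) as pool:
    latencies = list(pool.map(infer, range(8)))

for i, latency in enumerate(latencies):
    print(f"request #{i}: {latency * 1000.0:.4f} milliseconds")
```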
diff --git a/docs/source/single-stream.png b/docs/source/single-stream.png new file mode 100644 index 0000000000..c324a2045d Binary files /dev/null and b/docs/source/single-stream.png differ diff --git a/examples/flask/README.md b/examples/flask/README.md index 3f6b02e3f4..12e4ea97ff 100644 --- a/examples/flask/README.md +++ b/examples/flask/README.md @@ -55,11 +55,11 @@ python client.py ~/Downloads/resnet18_pruned.onnx ``` Output: ```bash -[ INFO onnx.py: 92 - generate_random_inputs() ] Generating 1 random inputs -[ INFO onnx.py: 102 - generate_random_inputs() ] -- random input #0 of shape = [1, 3, 224, 224] -Sending 1 input tensors to http://0.0.0.0:5543/predict -Recieved response of 2 output tensors: -Round-trip time took 13.4261 milliseconds - output #0: shape (1, 1000) - output #1: shape (1, 1000) +[ INFO onnx.py: 127 - generate_random_inputs() ] -- generating random input #0 of shape = [1, 3, 224, 224] +[ INFO client.py: 152 - main() ] Sending 1 input tensors to http://0.0.0.0:5543/run +[ DEBUG client.py: 102 - _post() ] Sending POST request to http://0.0.0.0:5543/run +[ INFO client.py: 159 - main() ] Round-trip time took 13.3283 milliseconds +[ INFO client.py: 160 - main() ] Received response of 2 output tensors: +[ INFO client.py: 163 - main() ] output #0: shape (1, 1000) +[ INFO client.py: 163 - main() ] output #1: shape (1, 1000) ``` diff --git a/examples/flask/client.py b/examples/flask/client.py index 1a2a204835..70a290aed3 100644 --- a/examples/flask/client.py +++ b/examples/flask/client.py @@ -18,17 +18,18 @@ ########## Command help: -usage: client.py [-h] [-s BATCH_SIZE] [-a ADDRESS] [-p PORT] onnx_filepath +usage: client.py [-h] [-b BATCH_SIZE] [-a ADDRESS] [-p PORT] model_path -Communicate with a Flask server hosting an ONNX model with the -DeepSparse Engine as inference backend. +Communicate with a Flask server hosting an ONNX model with the DeepSparse +Engine as inference backend. positional arguments: - onnx_filepath The full filepath of the ONNX model file + model_path The full filepath of the ONNX model file or SparseZoo + stub of model optional arguments: -h, --help show this help message and exit - -s BATCH_SIZE, --batch_size BATCH_SIZE + -b BATCH_SIZE, --batch-size BATCH_SIZE The batch size to run the analysis for -a ADDRESS, --address ADDRESS The IP address of the hosted model @@ -41,11 +42,65 @@ """ import argparse +import os import time +from typing import Any, Callable, List + import numpy import requests -from deepsparse.utils import arrays_to_bytes, bytes_to_arrays, generate_random_inputs +from deepsparse.utils import ( + arrays_to_bytes, + bytes_to_arrays, + generate_random_inputs, + log_init, +) + + +_LOGGER = log_init(os.path.basename(__file__)) + + +class EngineFlaskClient: + """ + Client object for interacting with an HTTP server invoked with `engine_flask_server`. + + :param address: IP address of server to query + :param port: port that the server is running on + :param preprocessing_fn: function to preprocess inputs to the run argument before + sending inputs to the model server. Defaults to the `arrays_to_bytes` function + for serializing lists of numpy arrays + :param postprocessing_fn: function to postprocess outputs from model server + inferences.
Defaults to the `bytes_to_arrays` function for de-serializing + lists of numpy arrays + """ + + def __init__( + self, + address: str, + port: str, + preprocessing_fn: Callable[[Any], Any] = arrays_to_bytes, + postprocessing_fn: Callable[[Any], Any] = bytes_to_arrays, + ): + self.url = f"http://{address}:{port}" + self.preprocessing_fn = preprocessing_fn + self.postprocessing_fn = postprocessing_fn + + def run(self, inp: List[numpy.ndarray]) -> List[numpy.ndarray]: + """ + Client function for running a forward pass of the server model. + + :param inp: the list of inputs to pass to the server for inference. + The expected order is the inputs order as defined in the ONNX graph + :return: the list of outputs from the server after executing over the inputs + """ + data = self.preprocessing_fn(inp) + response = self._post("run", data=data) + return self.postprocessing_fn(response) + + def _post(self, route: str, data: Any): + route_url = f"{self.url}/{route}" + _LOGGER.debug(f"Sending POST request to {route_url}") + return requests.post(route_url, data=data).content def parse_args(): @@ -57,14 +112,14 @@ def parse_args(): ) parser.add_argument( - "onnx_filepath", + "model_path", type=str, - help="The full filepath of the ONNX model file", + help="The full filepath of the ONNX model file or SparseZoo stub of model", ) parser.add_argument( - "-s", - "--batch_size", + "-b", + "--batch-size", type=int, default=1, help="The batch size to run the analysis for", @@ -89,32 +144,23 @@ def parse_args(): def main(): args = parse_args() - onnx_filepath = args.onnx_filepath - batch_size = args.batch_size - address = args.address - port = args.port - prediction_url = f"http://{address}:{port}/predict" + engine = EngineFlaskClient(args.address, args.port) - inputs = generate_random_inputs(onnx_filepath, batch_size) + inputs = generate_random_inputs(args.model_path, args.batch_size) - print(f"Sending {len(inputs)} input tensors to {prediction_url}") + _LOGGER.info(f"Sending {len(inputs)} input tensors to {engine.url}/run") start = time.time() - # Encode inputs - data = arrays_to_bytes(inputs) - # Send data to server for inference - response = requests.post(prediction_url, data=data) - # Decode outputs - outputs = bytes_to_arrays(response.content) + outputs = engine.run(inputs) end = time.time() elapsed_time = end - start - print(f"Received response of {len(outputs)} output tensors:") - print(f"Round-trip time took {elapsed_time * 1000.0:.4f} milliseconds") + _LOGGER.info(f"Round-trip time took {elapsed_time * 1000.0:.4f} milliseconds") + _LOGGER.info(f"Received response of {len(outputs)} output tensors:") for i, out in enumerate(outputs): - print(f" output #{i}: shape {out.shape}") + _LOGGER.info(f"\toutput #{i}: shape {out.shape}") if __name__ == "__main__": diff --git a/examples/flask/server.py b/examples/flask/server.py index 88d3e9595b..a294cc36f7 100644 --- a/examples/flask/server.py +++ b/examples/flask/server.py @@ -18,21 +18,29 @@ ########## Command help: -usage: server.py [-h] [-s BATCH_SIZE] [-j NUM_CORES] [-a ADDRESS] [-p PORT] - onnx_filepath +usage: server.py [-h] [-b BATCH_SIZE] [-c NUM_CORES] [-s NUM_SOCKETS] + [--scheduler SCHEDULER] [-a ADDRESS] [-p PORT] + model_path Host an ONNX model as a server, using the DeepSparse Engine and Flask positional arguments: - onnx_filepath The full filepath of the ONNX model file + model_path The full filepath of the ONNX model file or SparseZoo + stub for the model optional arguments: -h, --help show this help message and exit - -s BATCH_SIZE, --batch_size
BATCH_SIZE - The batch size to run the analysis for - -j NUM_CORES, --num_cores NUM_CORES - The number of physical cores to run the analysis on, + -b BATCH_SIZE, --batch-size BATCH_SIZE + The batch size to run the engine with + -c NUM_CORES, --num-cores NUM_CORES + The number of physical cores to run the engine on, defaults to all physical cores available on the system + -s NUM_SOCKETS, --num-sockets NUM_SOCKETS + The number of physical sockets to run the engine on, + defaults to all physical sockets available on the + system + --scheduler SCHEDULER + The kind of scheduler to run with. Defaults to multi_stream -a ADDRESS, --address ADDRESS The IP address of the hosted model -p PORT, --port PORT The port that the model is hosted on @@ -44,12 +52,72 @@ """ import argparse +import os import flask from flask_cors import CORS -from deepsparse import compile_model -from deepsparse.utils import arrays_to_bytes, bytes_to_arrays +from deepsparse import Scheduler, compile_model +from deepsparse.utils import arrays_to_bytes, bytes_to_arrays, log_init + + +_LOGGER = log_init(os.path.basename(__file__)) + + +def engine_flask_server( + model_path: str, + batch_size: int = 1, + num_cores: int = None, + num_sockets: int = None, + scheduler: Scheduler = Scheduler.multi_stream, + address: str = "0.0.0.0", + port: str = "5543", +) -> flask.Flask: + """ + + :param model_path: Either a path to the model's onnx file, a SparseZoo model stub + prefixed by 'zoo:', a SparseZoo Model object, or a SparseZoo ONNX File + object that defines the neural network + :param batch_size: The batch size of the inputs to be used with the model + :param num_cores: The number of physical cores to run the model on. + Pass None or 0 to run on the max number of cores + in one socket for the current machine, default None + :param num_sockets: The number of physical sockets to run the model on. + Pass None or 0 to run on the max number of sockets for the + current machine, default None + :param scheduler: The kind of scheduler to execute with. Defaults to multi_stream + :param address: IP address to run on. Default is 0.0.0.0 + :param port: port to run on.
Default is 5543 + :return: launches a flask server on the given address and port that can run the + given model on the DeepSparse engine via HTTP requests + """ + _LOGGER.info(f"Compiling model at {model_path}") + engine = compile_model(model_path, batch_size, num_cores, num_sockets, scheduler) + _LOGGER.info(engine) + + app = flask.Flask(__name__) + CORS(app) + + @app.route("/run", methods=["POST"]) + def run(): + data = flask.request.get_data() + + inputs = bytes_to_arrays(data) + _LOGGER.info(f"Received {len(inputs)} inputs from client") + + _LOGGER.info("Executing model") + outputs, elapsed_time = engine.timed_run(inputs) + + _LOGGER.info(f"Inference time took {elapsed_time * 1000.0:.4f} milliseconds") + _LOGGER.info(f"Produced {len(outputs)} output tensors") + return arrays_to_bytes(outputs) + + @app.route("/info", methods=["GET"]) + def info(): + return flask.jsonify({"model_path": model_path, "engine": repr(engine)}) + + _LOGGER.info("Starting Flask app") + app.run(host=address, port=port, debug=False, threaded=True) def parse_args(): @@ -60,28 +128,44 @@ def parse_args(): ) parser.add_argument( - "onnx_filepath", + "model_path", type=str, - help="The full filepath of the ONNX model file", + help="The full filepath of the ONNX model file or SparseZoo stub for the model", ) parser.add_argument( - "-s", - "--batch_size", + "-b", + "--batch-size", type=int, default=1, - help="The batch size to run the analysis for", + help="The batch size to run the engine with", ) parser.add_argument( - "-j", - "--num_cores", + "-c", + "--num-cores", type=int, default=0, help=( - "The number of physical cores to run the analysis on, " + "The number of physical cores to run the engine on, " "defaults to all physical cores available on the system" ), ) + parser.add_argument( + "-s", + "--num-sockets", + type=int, + default=None, + help=( + "The number of physical sockets to run the engine on, " + "defaults to all physical sockets available on the system" + ), + ) + parser.add_argument( + "--scheduler", + type=str, + default=Scheduler.multi_stream, + help="The kind of scheduler to run with.
Defaults to multi_stream", + ) parser.add_argument( "-a", "--address", @@ -100,47 +184,18 @@ def parse_args(): return parser.parse_args() -def create_model_inference_app( - model_path: str, batch_size: int, num_cores: int, address: str, port: str -) -> flask.Flask: - print(f"Compiling model at {model_path}") - engine = compile_model(model_path, batch_size, num_cores) - print(engine) - - app = flask.Flask(__name__) - CORS(app) - - @app.route("/predict", methods=["POST"]) - def predict(): - data = flask.request.get_data() - - inputs = bytes_to_arrays(data) - print(f"Received {len(inputs)} inputs from client") - - print("Executing model") - outputs, elapsed_time = engine.timed_run(inputs) - - print(f"Inference time took {elapsed_time * 1000.0:.4f} milliseconds") - print(f"Produced {len(outputs)} output tensors") - return arrays_to_bytes(outputs) - - @app.route("/info", methods=["GET"]) - def info(): - return flask.jsonify({"model_path": model_path, "engine": repr(engine)}) - - print("Starting Flask app") - app.run(host=address, port=port, debug=False, threaded=True) - - def main(): args = parse_args() - onnx_filepath = args.onnx_filepath - batch_size = args.batch_size - num_cores = args.num_cores - address = args.address - port = args.port - create_model_inference_app(onnx_filepath, batch_size, num_cores, address, port) + engine_flask_server( + args.model_path, + args.batch_size, + args.num_cores, + args.num_sockets, + args.scheduler, + args.address, + args.port, + ) if __name__ == "__main__":