
Commit

Address review comments
Tabrizian committed Apr 16, 2024
1 parent dcc6aa6 commit 8441315
Showing 6 changed files with 25 additions and 21 deletions.
19 changes: 11 additions & 8 deletions Conceptual_Guide/Part_7-iterative_scheduling/README.md
@@ -29,15 +29,17 @@
# Deploying a GPT-2 Model using Python Backend and Iterative Scheduling

In this tutorial, we will deploy a GPT-2 model using the Python backend and
-[iterative scheduling](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#iterative-sequences).
+demonstrate the
+[iterative scheduling](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html#iterative-sequences)
+feature.

## Prerequisites

Before getting started with this tutorial, make sure you're familiar
with the following concepts:

-* [Python Backend](https://github.com/triton-inference-server/python_backend)
* [Triton-Server Quick Start](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/getting_started/quickstart.html)
+* [Python Backend](https://github.com/triton-inference-server/python_backend)

## Iterative Scheduling

@@ -56,9 +58,13 @@ In this tutorial we deploy two models:
* simple-gpt2: This model receives a batch of requests and proceeds to the next
batch only when it is done generating tokens for the current batch.

-* iterative-gpt2: This model uses iterative scheduling and is able to process
-new sequences in the batch even when it is still generating tokens for the
-previous sequences.
+* iterative-gpt2: This model uses iterative scheduling to process
+new sequences in a batch even when it is still generating tokens for the
+previous sequences.

+### Demo
+
+[![asciicast](https://asciinema.org/a/WeDlFwRuOip6q7EgQMZmWKRE7.svg)](https://asciinema.org/a/WeDlFwRuOip6q7EgQMZmWKRE7)
+

### Step 1: Prepare the Server Environment
@@ -121,9 +127,6 @@ python3 client/client.py --model iterative-gpt2

As you can see, the tokens for both prompts are getting generated simultaneously.

-### Demo
-
-[![asciicast](https://asciinema.org/a/WeDlFwRuOip6q7EgQMZmWKRE7.svg)](https://asciinema.org/a/WeDlFwRuOip6q7EgQMZmWKRE7)

## Next Steps

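To make the README's point concrete, here is a minimal sketch of a client that streams two prompts to the iterative-gpt2 model at the same time, in the spirit of the tutorial's client/client.py shown below. The gRPC address, the text_output tensor name, and the prompts are assumptions for illustration; the actual client adds callback bookkeeping, events, and display handling that are omitted here.

```python
# Minimal sketch, not the tutorial's client: streams two prompts concurrently
# so their generated tokens interleave when iterative scheduling is enabled.
import time
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient


def on_result(tag, result, error):
    # Called once per streamed response on the given stream.
    if error is not None:
        print(f"[{tag}] error: {error}")
        return
    # "text_output" is an assumed output tensor name.
    print(f"[{tag}] {result.as_numpy('text_output')}")


def stream_prompt(client, model_name, tag, prompt):
    inputs = [grpcclient.InferInput("text_input", [1, 1], "BYTES")]
    inputs[0].set_data_from_numpy(np.array([[prompt]], dtype=np.object_))
    client.start_stream(callback=partial(on_result, tag))
    client.async_stream_infer(model_name=model_name, inputs=inputs)


if __name__ == "__main__":
    # One client (and therefore one stream) per prompt, mirroring
    # client1/client2 in the tutorial's client.py.
    client_a = grpcclient.InferenceServerClient("localhost:8001")
    client_b = grpcclient.InferenceServerClient("localhost:8001")
    stream_prompt(client_a, "iterative-gpt2", "A", "placeholder prompt one")
    stream_prompt(client_b, "iterative-gpt2", "B", "placeholder prompt two")

    time.sleep(30)  # crude wait; the real client synchronizes with threading.Event
    client_a.stop_stream()
    client_b.stop_stream()
```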
5 changes: 2 additions & 3 deletions Conceptual_Guide/Part_7-iterative_scheduling/client/client.py
@@ -28,7 +28,6 @@
import threading
import time
from functools import partial
-from threading import Event

import numpy as np
import tritonclient.grpc as grpcclient
@@ -70,8 +69,8 @@ def run_inferences(url, model_name, display):
inputs1.append(grpcclient.InferInput("text_input", [1, 1], "BYTES"))
inputs1[0].set_data_from_numpy(np.array([[prompt2]], dtype=np.object_))

-event1 = Event()
-event2 = Event()
+event1 = threading.Event()
+event2 = threading.Event()
client1.start_stream(callback=partial(partial(client1_callback, display), event1))
client2.start_stream(callback=partial(partial(client2_callback, display), event2))

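A quick note on the callback wiring above: `partial(partial(client1_callback, display), event1)` binds `display` first and `event1` second, so it is equivalent to binding both arguments in one call. A tiny standalone sketch of that equivalence, with stand-in names and values:

```python
from functools import partial


def client_callback(display, event, result, error):
    # Matches the order produced by the nested partial in client.py:
    # display and event are pre-bound; (result, error) come from the stream.
    print(display, event, result, error)


display, event = "display", "event"

nested = partial(partial(client_callback, display), event)
flat = partial(client_callback, display, event)

# Both now take only (result, error), which is what Triton's start_stream
# callback is invoked with for every streamed response.
nested("some result", None)
flat("some result", None)
```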
5 changes: 5 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/1/model.py
@@ -59,6 +59,11 @@ def auto_complete_config(config):
for output in outputs:
config.add_output(output)

+# Enable decoupled mode
+transaction_policy = {"decoupled": True}
+config.set_model_transaction_policy(transaction_policy)
+config.set_max_batch_size(8)
+
return config

def create_batch(self, requests):
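The change above moves the decoupled transaction policy and the maximum batch size out of config.pbtxt and into the model's auto_complete_config. For readers new to this python_backend feature, here is a minimal sketch of what such a function can look like; the tensor names, data types, and dims are assumptions for illustration, and the iterative_sequence setting itself stays in the config.pbtxt diff that follows.

```python
class TritonPythonModel:
    @staticmethod
    def auto_complete_config(config):
        # Declare model I/O so Triton can auto-complete a minimal config.pbtxt.
        # These names/dims are illustrative, not taken from this commit.
        config.add_input(
            {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]}
        )
        config.add_output(
            {"name": "text_output", "data_type": "TYPE_STRING", "dims": [1]}
        )

        # The settings this commit removes from config.pbtxt:
        config.set_max_batch_size(8)
        # Decoupled mode allows multiple responses per request, which
        # streaming token generation requires.
        config.set_model_transaction_policy({"decoupled": True})
        return config
```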
4 changes: 0 additions & 4 deletions Conceptual_Guide/Part_7-iterative_scheduling/model_repository/iterative-gpt2/config.pbtxt
@@ -25,10 +25,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

backend: "python"
-max_batch_size: 8
-model_transaction_policy {
-decoupled: true
-}
sequence_batching {
iterative_sequence: true
control_input: [{
7 changes: 7 additions & 0 deletions Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/1/model.py
@@ -59,6 +59,11 @@ def auto_complete_config(config):
for output in outputs:
config.add_output(output)

+transaction_policy = {"decoupled": True}
+config.set_dynamic_batching()
+config.set_max_batch_size(8)
+config.set_model_transaction_policy(transaction_policy)
+
return config

def init_state(self, requests):
@@ -84,6 +89,7 @@ def create_batch(self, requests):
Returns:
input_ids (torch.Tensor): A tensor containing the processed input IDs.
attention_mask (torch.Tensor): A tensor containing the attention mask.
+mapping (list): A list of indices that map the input tensors to the requests.
"""

input_ids = []
@@ -134,6 +140,7 @@ def send_responses(self, requests, outputs, mapping):
Args:
requests (list): List of Triton InferenceRequest objects.
outputs (list): List of output tensors generated by the model.
+mapping (list): List of indices that map the input tensors to the requests.
Returns:
None
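The new mapping entries in the docstrings above describe how the model ties batch rows back to the requests they came from. A small illustrative sketch of that idea, with stand-in values rather than the tutorial's actual implementation:

```python
# Illustrative only: stand-ins for the tutorial's create_batch/send_responses.
requests = ["request-A", "request-B", "request-C"]   # placeholder request objects

# create_batch flattens per-request inputs into one batch and records, for each
# batch row, the index of the request that produced it.
batch_rows, mapping = [], []
for idx, request in enumerate(requests):
    batch_rows.append(f"input ids for {request}")    # stand-in for a tensor row
    mapping.append(idx)

# send_responses walks the outputs with the same mapping, so each output row is
# routed back to the request it belongs to.
outputs = [row.upper() for row in batch_rows]        # stand-in for model outputs
for row_idx, request_idx in enumerate(mapping):
    print(f"{requests[request_idx]} <- {outputs[row_idx]}")
```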
6 changes: 0 additions & 6 deletions Conceptual_Guide/Part_7-iterative_scheduling/model_repository/simple-gpt2/config.pbtxt
@@ -25,12 +25,6 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

backend: "python"
-max_batch_size: 8
-model_transaction_policy {
-decoupled: true
-}
-
-dynamic_batching {}

instance_group [
{
