From f99c0364c95829d5dacefee669f29b7803faaad7 Mon Sep 17 00:00:00 2001 From: Ceyda Cinarel Date: Thu, 15 Oct 2020 16:57:42 +0900 Subject: [PATCH] add scale workers tab --- torchserve_dashboard/api.py | 16 ++++++++++++++++ torchserve_dashboard/dash.py | 32 +++++++++++++++++++++++++++++++- 2 files changed, 47 insertions(+), 1 deletion(-) diff --git a/torchserve_dashboard/api.py b/torchserve_dashboard/api.py index 7b8a09f..2bca0e7 100644 --- a/torchserve_dashboard/api.py +++ b/torchserve_dashboard/api.py @@ -5,6 +5,7 @@ import streamlit as st + def raise_on_not200(response): if response.status_code != 200: st.write("There was an error!") @@ -98,3 +99,18 @@ def change_model_default(M_API, model_name, version): req_url += "/set-default" res = client.put(req_url) return res.json() + + +def change_model_workers(M_API, model_name, version=None, min_worker=None, max_worker=None, number_gpu=None): + req_url = M_API + "/models/" + model_name + if version: + req_url += "/" + version + req_url += "?synchronous=false" + if min_worker: + req_url += "&min_worker=" + str(min_worker) + if max_worker: + req_url += "&max_worker=" + str(max_worker) + if number_gpu: + req_url += "&number_gpu=" + str(number_gpu) + res = client.put(req_url) + return res.json() diff --git a/torchserve_dashboard/dash.py b/torchserve_dashboard/dash.py index 0e25386..9d4a3c8 100644 --- a/torchserve_dashboard/dash.py +++ b/torchserve_dashboard/dash.py @@ -144,7 +144,6 @@ def get_model_store(): if proceed and model_name != default_key and version != default_key: res = tsa.delete_model(M_API, model_name, version) last_res()[0] = res - proceed=False rerun() with st.beta_expander(label="Get model details", expanded=False): @@ -164,3 +163,34 @@ def get_model_store(): elif version != default_key: res = tsa.get_model(M_API, model_name, version) st.write(res) + + with st.beta_expander(label="Scale workers", expanded=False): + st.markdown("# Scale workers [(docs)](https://pytorch.org/serve/management_api.html#scale-workers)") + model_name = st.selectbox("Pick model", [default_key] + loaded_models_names, index=0) + if model_name != default_key: + default_version = tsa.get_model(M_API,model_name)[0]["modelVersion"] + st.write(f"default version {default_version}") + versions = tsa.get_model(M_API,model_name, list_all=False) + versions = [m["modelVersion"] for m in versions] + version = st.selectbox("Choose version", ["All"] + versions, index=0) + + col1, col2, col3 = st.beta_columns(3) + min_worker = col1.number_input(label="min_worker(optional)", value=-1, min_value=-1, step=1) + max_worker = col2.number_input(label="max_worker(optional)", value=-1, min_value=-1, step=1) + number_gpu = col3.number_input(label="number_gpu(optional)", value=-1, min_value=-1, step=1) + proceed = st.button("Apply") + if proceed and model_name != default_key: + # number_input can't be set to None + if version == "All": + version=None + if min_worker == -1: + min_worker=None + if max_worker == -1: + max_worker=None + if number_gpu == -1: + number_gpu=None + + res = tsa.change_model_workers(M_API, model_name, version=version, min_worker=min_worker, max_worker=max_worker, number_gpu=number_gpu) + last_res()[0] = res + rerun() +