Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Enable deferred unregistering of shared memory regions after inference #7743

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
429 changes: 246 additions & 183 deletions qa/L0_cuda_shared_memory/cuda_shared_memory_test.py

Large diffs are not rendered by default.

17 changes: 16 additions & 1 deletion qa/L0_cuda_shared_memory/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ for client_type in http grpc; do
fi

export CLIENT_TYPE=$client_type
CLIENT_LOG="./unregister_shm.$client_type.client.log"
CLIENT_LOG="./unregister_shm_during_inference_$client_type.client.log"
set +e
python3 $SHM_TEST TestCudaSharedMemoryUnregister.test_unregister_shm_during_inference_$client_type >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
Expand All @@ -116,6 +116,21 @@ for client_type in http grpc; do
fi
fi

CLIENT_LOG="./unregister_shm_after_inference_$client_type.client.log"
python3 $SHM_TEST TestCudaSharedMemoryUnregister.test_unregister_shm_after_inference_$client_type >>$CLIENT_LOG 2>&1
if [ $? -ne 0 ]; then
cat $CLIENT_LOG
echo -e "\n***\n*** Test Failed\n***"
RET=1
else
check_test_results $TEST_RESULT_FILE 1
if [ $? -ne 0 ]; then
cat $TEST_RESULT_FILE
echo -e "\n***\n*** Test Result Verification Failed\n***"
RET=1
fi
fi

kill $SERVER_PID
wait $SERVER_PID
if [ $? -ne 0 ]; then
Expand Down
321 changes: 186 additions & 135 deletions qa/L0_shared_memory/shared_memory_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,167 +457,218 @@ def test_python_client_leak(self):


class TestSharedMemoryUnregister(SystemSharedMemoryTestBase):
def _test_unregister_shm_request_pass(self, shm_names):
    """Issue an unregister-all request while inference is still in flight.

    The request must succeed without error: regions that are in use are only
    marked for deferred unregistration, so all regions in `shm_names` must
    still be registered when this method returns.
    """
    self._test_shm_found(shm_names)

    # Unregister all should not result in an error.
    # If shared memory regions are in use, they will be marked and
    # unregistered after the inference is completed.
    with httpclient.InferenceServerClient(
        "localhost:8000", verbose=True
    ) as second_client:
        second_client.unregister_system_shared_memory()

    # Number of shared memory regions should be the same as the inference
    # is not completed yet.
    self._test_shm_found(shm_names)

def _test_shm_not_found(self, shm_names):
    """Assert that no region in `shm_names` is registered with the server.

    Queries each region's status individually and expects the server to
    report it as unknown.
    """
    self.assertGreater(len(shm_names), 0)
    second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True)

    for shm_name in shm_names:
        with self.assertRaises(utils.InferenceServerException) as ex:
            second_client.get_system_shared_memory_status(shm_name)
        self.assertIn(
            f"Unable to find system shared memory region: '{shm_name}'",
            str(ex.exception),
        )
def _test_shm_found(self, shm_names):
    """Assert that exactly the regions in `shm_names` are registered.

    Fetches the full status list and checks both the count and that every
    reported region name is one of the expected names.
    """
    self.assertGreater(len(shm_names), 0)
    second_client = httpclient.InferenceServerClient("localhost:8000", verbose=True)

    status = second_client.get_system_shared_memory_status()
    self.assertEqual(len(status), len(shm_names))

    for shm_info in status:
        self.assertIn(shm_info["name"], shm_names)
def test_unregister_shm_during_inference_http(self):
    """Unregister requested mid-inference (HTTP) is deferred, then applied.

    The unregister-all call made while the request is in flight must succeed,
    the regions must remain registered until the inference finishes, and they
    must be gone afterwards without a second unregister call.
    """
    self.triton_client.unregister_system_shared_memory()
    self._configure_server()
    shm_names = ["input0_data", "input1_data", "output0_data", "output1_data"]

    inputs = [
        httpclient.InferInput("INPUT0", [1, 16], "INT32"),
        httpclient.InferInput("INPUT1", [1, 16], "INT32"),
    ]
    outputs = [
        httpclient.InferRequestedOutput("OUTPUT0", binary_data=True),
        httpclient.InferRequestedOutput("OUTPUT1", binary_data=False),
    ]

    inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
    inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
    outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
    outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)

    async_request = self.triton_client.async_infer(
        model_name="simple", inputs=inputs, outputs=outputs
    )

    # Ensure inference started
    time.sleep(2)

    # Try unregister shm regions during inference
    self._test_unregister_shm_request_pass(shm_names)

    # Blocking call
    async_request.get_result()

    # Test that all shm regions are successfully unregistered after
    # inference without needing to call unregister again.
    self._test_shm_not_found(shm_names)

def test_unregister_shm_after_inference_http(self):
    """Regions persist through inference (HTTP) until explicitly unregistered.

    With no unregister call issued during the request, all regions must stay
    registered during and after inference; an explicit unregister-all call
    made afterwards must then remove them.
    """
    self.triton_client.unregister_system_shared_memory()
    self._configure_server()
    shm_names = ["input0_data", "input1_data", "output0_data", "output1_data"]

    inputs = [
        httpclient.InferInput("INPUT0", [1, 16], "INT32"),
        httpclient.InferInput("INPUT1", [1, 16], "INT32"),
    ]
    outputs = [
        httpclient.InferRequestedOutput("OUTPUT0", binary_data=True),
        httpclient.InferRequestedOutput("OUTPUT1", binary_data=False),
    ]

    inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
    inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
    outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
    outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)

    async_request = self.triton_client.async_infer(
        model_name="simple", inputs=inputs, outputs=outputs
    )

    # Ensure inference started
    time.sleep(2)

    # Test all registered shm regions exist during inference.
    self._test_shm_found(shm_names)

    # Blocking call
    async_request.get_result()

    # Test all registered shm regions exist after inference, as the
    # unregister API has not been called.
    self._test_shm_found(shm_names)

    # Test all shm regions are successfully unregistered after calling the
    # unregister API once the inference has completed.
    self.triton_client.unregister_system_shared_memory()
    self._test_shm_not_found(shm_names)

def test_unregister_shm_during_inference_grpc(self):
    """Unregister requested mid-inference (gRPC) is deferred, then applied.

    Mirrors the HTTP variant but uses the gRPC client's callback-based
    async_infer: the unregister-all call made while the request is in flight
    must succeed, and the regions must be gone once the inference completes,
    without a second unregister call.
    """
    self.triton_client.unregister_system_shared_memory()
    self._configure_server()
    shm_names = ["input0_data", "input1_data", "output0_data", "output1_data"]

    inputs = [
        grpcclient.InferInput("INPUT0", [1, 16], "INT32"),
        grpcclient.InferInput("INPUT1", [1, 16], "INT32"),
    ]
    outputs = [
        grpcclient.InferRequestedOutput("OUTPUT0"),
        grpcclient.InferRequestedOutput("OUTPUT1"),
    ]

    inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
    inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
    outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
    outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)

    # Collects either the result or the error delivered by async_infer.
    def callback(user_data, result, error):
        if error:
            user_data.append(error)
        else:
            user_data.append(result)

    user_data = []

    self.triton_client.async_infer(
        model_name="simple",
        inputs=inputs,
        outputs=outputs,
        callback=partial(callback, user_data),
    )

    # Ensure inference started
    time.sleep(2)

    # Try unregister shm regions during inference
    self._test_unregister_shm_request_pass(shm_names)

    # Wait until the results are available in user_data
    time_out = 20
    while (len(user_data) == 0) and time_out > 0:
        time_out = time_out - 1
        time.sleep(1)
    time.sleep(2)

    # Test that all shm regions are successfully unregistered after
    # inference without needing to call unregister again.
    self._test_shm_not_found(shm_names)
def test_unregister_shm_after_inference_grpc(self):
    """Regions persist through inference (gRPC) until explicitly unregistered.

    Mirrors the HTTP variant but uses the gRPC client's callback-based
    async_infer: with no unregister call issued during the request, all
    regions must stay registered during and after inference; an explicit
    unregister-all call made afterwards must then remove them.
    """
    self.triton_client.unregister_system_shared_memory()
    self._configure_server()
    shm_names = ["input0_data", "input1_data", "output0_data", "output1_data"]

    inputs = [
        grpcclient.InferInput("INPUT0", [1, 16], "INT32"),
        grpcclient.InferInput("INPUT1", [1, 16], "INT32"),
    ]
    outputs = [
        grpcclient.InferRequestedOutput("OUTPUT0"),
        grpcclient.InferRequestedOutput("OUTPUT1"),
    ]

    inputs[0].set_shared_memory("input0_data", self.DEFAULT_SHM_BYTE_SIZE)
    inputs[1].set_shared_memory("input1_data", self.DEFAULT_SHM_BYTE_SIZE)
    outputs[0].set_shared_memory("output0_data", self.DEFAULT_SHM_BYTE_SIZE)
    outputs[1].set_shared_memory("output1_data", self.DEFAULT_SHM_BYTE_SIZE)

    # Collects either the result or the error delivered by async_infer.
    def callback(user_data, result, error):
        if error:
            user_data.append(error)
        else:
            user_data.append(result)

    user_data = []

    self.triton_client.async_infer(
        model_name="simple",
        inputs=inputs,
        outputs=outputs,
        callback=partial(callback, user_data),
    )

    # Ensure inference started
    time.sleep(2)

    # Test all registered shm regions exist during inference.
    self._test_shm_found(shm_names)

    # Wait until the results are available in user_data
    time_out = 20
    while (len(user_data) == 0) and time_out > 0:
        time_out = time_out - 1
        time.sleep(1)
    time.sleep(2)

    # Test all registered shm regions exist after inference, as the
    # unregister API has not been called.
    self._test_shm_found(shm_names)

    # Test all shm regions are successfully unregistered after calling the
    # unregister API once the inference has completed.
    self.triton_client.unregister_system_shared_memory()
    self._test_shm_not_found(shm_names)


if __name__ == "__main__":
Expand Down
Loading
Loading