Skip to content

Commit daa5518

Browse files
authored
model server might have already done a serialization. honor that by not decoding the request again if it is not already bytes or bytestream (#4987)
1 parent 2102bb7 commit daa5518

File tree

3 files changed

+48
-18
lines changed

3 files changed

+48
-18
lines changed

src/sagemaker/serve/model_server/multi_model_server/inference.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -46,18 +46,28 @@ def input_fn(input_data, content_type, context=None):
4646
if hasattr(schema_builder, "custom_input_translator"):
4747
deserialized_data = schema_builder.custom_input_translator.deserialize(
4848
(
49-
io.BytesIO(input_data)
50-
if type(input_data) == bytes
51-
else io.BytesIO(input_data.encode("utf-8"))
49+
io.BytesIO(input_data.encode("utf-8"))
50+
if not any(
51+
[
52+
isinstance(input_data, bytes),
53+
isinstance(input_data, bytearray),
54+
]
55+
)
56+
else io.BytesIO(input_data)
5257
),
5358
content_type,
5459
)
5560
else:
5661
deserialized_data = schema_builder.input_deserializer.deserialize(
5762
(
58-
io.BytesIO(input_data)
59-
if type(input_data) == bytes
60-
else io.BytesIO(input_data.encode("utf-8"))
63+
io.BytesIO(input_data.encode("utf-8"))
64+
if not any(
65+
[
66+
isinstance(input_data, bytes),
67+
isinstance(input_data, bytearray),
68+
]
69+
)
70+
else io.BytesIO(input_data)
6171
),
6272
content_type[0],
6373
)

src/sagemaker/serve/model_server/torchserve/inference.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,28 @@ def input_fn(input_data, content_type):
6868
if hasattr(schema_builder, "custom_input_translator"):
6969
deserialized_data = schema_builder.custom_input_translator.deserialize(
7070
(
71-
io.BytesIO(input_data)
72-
if type(input_data) == bytes
73-
else io.BytesIO(input_data.encode("utf-8"))
71+
io.BytesIO(input_data.encode("utf-8"))
72+
if not any(
73+
[
74+
isinstance(input_data, bytes),
75+
isinstance(input_data, bytearray),
76+
]
77+
)
78+
else io.BytesIO(input_data)
7479
),
7580
content_type,
7681
)
7782
else:
7883
deserialized_data = schema_builder.input_deserializer.deserialize(
7984
(
80-
io.BytesIO(input_data)
81-
if type(input_data) == bytes
82-
else io.BytesIO(input_data.encode("utf-8"))
85+
io.BytesIO(input_data.encode("utf-8"))
86+
if not any(
87+
[
88+
isinstance(input_data, bytes),
89+
isinstance(input_data, bytearray),
90+
]
91+
)
92+
else io.BytesIO(input_data)
8393
),
8494
content_type[0],
8595
)

src/sagemaker/serve/model_server/torchserve/xgboost_inference.py

+16-6
Original file line numberDiff line numberDiff line change
@@ -71,18 +71,28 @@ def input_fn(input_data, content_type):
7171
if hasattr(schema_builder, "custom_input_translator"):
7272
return schema_builder.custom_input_translator.deserialize(
7373
(
74-
io.BytesIO(input_data)
75-
if type(input_data) == bytes
76-
else io.BytesIO(input_data.encode("utf-8"))
74+
io.BytesIO(input_data.encode("utf-8"))
75+
if not any(
76+
[
77+
isinstance(input_data, bytes),
78+
isinstance(input_data, bytearray),
79+
]
80+
)
81+
else io.BytesIO(input_data)
7782
),
7883
content_type,
7984
)
8085
else:
8186
return schema_builder.input_deserializer.deserialize(
8287
(
83-
io.BytesIO(input_data)
84-
if type(input_data) == bytes
85-
else io.BytesIO(input_data.encode("utf-8"))
88+
io.BytesIO(input_data.encode("utf-8"))
89+
if not any(
90+
[
91+
isinstance(input_data, bytes),
92+
isinstance(input_data, bytearray),
93+
]
94+
)
95+
else io.BytesIO(input_data)
8696
),
8797
content_type[0],
8898
)

0 commit comments

Comments
 (0)