Skip to content

Commit 8692b13

Browse files
feat(common): add parse_sse_stream_data function
This function parses the raw buffer recieved from sse requests and yields a parsed dictionary
1 parent 3fe6243 commit 8692b13

File tree

2 files changed

+87
-1
lines changed

2 files changed

+87
-1
lines changed

ibm_watson/common.py

+22-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# coding: utf-8
22

3-
# Copyright 2019 IBM All Rights Reserved.
3+
# Copyright 2019, 2024 IBM All Rights Reserved.
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
66
# you may not use this file except in compliance with the License.
@@ -15,7 +15,9 @@
1515
# limitations under the License.
1616

1717
import platform
18+
import json
1819
from .version import __version__
20+
from typing import Iterator
1921

2022
SDK_ANALYTICS_HEADER = 'X-IBMCloud-SDK-Analytics'
2123
USER_AGENT_HEADER = 'User-Agent'
@@ -48,3 +50,22 @@ def get_sdk_headers(service_name, service_version, operation_id):
4850
operation_id)
4951
headers[USER_AGENT_HEADER] = get_user_agent()
5052
return headers
53+
54+
55+
def parse_sse_stream_data(response) -> Iterator[dict]:
56+
event_message = None # Can be used in the future to return the event message to the user
57+
data_json = None
58+
59+
for chunk in response.iter_lines():
60+
decoded_chunk = chunk.decode("utf-8")
61+
62+
if decoded_chunk.find("event", 0, len("event")) == 0:
63+
event_message = decoded_chunk[len("event") + 2:]
64+
elif decoded_chunk.find("data", 0, len("data")) == 0:
65+
data_json_str = decoded_chunk[len("data") + 2:]
66+
data_json = json.loads(data_json_str)
67+
68+
if event_message and data_json is not None:
69+
yield data_json
70+
event_message = None
71+
data_json = None

test/integration/test_assistant_v2.py

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# coding: utf-8
2+
3+
# Copyright 2019, 2024 IBM All Rights Reserved.
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from unittest import TestCase
18+
import ibm_watson
19+
from ibm_watson.assistant_v2 import MessageInput
20+
from ibm_watson.common import parse_sse_stream_data
21+
import pytest
22+
import json
23+
from ibm_cloud_sdk_core.authenticators import IAMAuthenticator
24+
25+
class TestAssistantV2(TestCase):
26+
27+
def setUp(self):
28+
29+
with open('./auth.json') as f:
30+
data = json.load(f)
31+
assistant_auth = data.get("assistantv2")
32+
self.assistant_id = assistant_auth.get("assistantId")
33+
self.environment_id = assistant_auth.get("environmentId")
34+
35+
self.authenticator = IAMAuthenticator(apikey=assistant_auth.get("apikey"))
36+
self.assistant = ibm_watson.AssistantV2(version='2024-08-25', authenticator=self.authenticator)
37+
self.assistant.set_service_url(assistant_auth.get("serviceUrl"))
38+
self.assistant.set_default_headers({
39+
'X-Watson-Learning-Opt-Out': '1',
40+
'X-Watson-Test': '1'
41+
})
42+
43+
def test_list_assistants(self):
44+
response = self.assistant.list_assistants().get_result()
45+
assert response is not None
46+
47+
def test_message_stream_stateless(self):
48+
input = MessageInput(message_type="text", text="can you list the steps to create a custom extension?")
49+
user_id = "Angelo"
50+
51+
response = self.assistant.message_stream_stateless(self.assistant_id, self.environment_id, input=input, user_id=user_id).get_result()
52+
53+
for data in parse_sse_stream_data(response):
54+
# One of these items must exist
55+
# assert "partial_item" in data_json or "complete_item" in data_json or "final_item" in data_json
56+
57+
if "partial_item" in data:
58+
assert data["partial_item"]["text"] is not None
59+
elif "complete_item" in data:
60+
assert data["complete_item"]["text"] is not None
61+
elif "final_response" in data:
62+
assert data["final_response"] is not None
63+
else:
64+
pytest.fail("Should be impossible to get here")
65+

0 commit comments

Comments
 (0)