Skip to content

Commit

Permalink
fix chat tests due to conversation_id being optional
Browse files Browse the repository at this point in the history
  • Loading branch information
lfayoux committed Aug 14, 2023
1 parent ee55ca9 commit f71644d
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 18 deletions.
2 changes: 1 addition & 1 deletion cohere/responses/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def from_dict(cls, response: Dict[str, Any], message: str, client) -> "Chat":
response_id=response["response_id"],
generation_id=response["generation_id"],
message=message,
conversation_id=response["conversation_id"],
conversation_id=response.get("conversation_id"), # optional
text=response.get("text"),
prompt=response.get("prompt"), # optional
chatlog=response.get("chatlog"), # optional
Expand Down
17 changes: 0 additions & 17 deletions tests/sync/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ class TestChat(unittest.TestCase):
def test_simple_success(self):
prediction = co.chat("Yo what up?", max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)
self.assertTrue(prediction.meta)
self.assertTrue(prediction.meta["api_version"])
self.assertTrue(prediction.meta["api_version"]["version"])
Expand All @@ -23,12 +22,10 @@ def test_multi_replies(self):
for _ in range(num_replies):
prediction = prediction.respond("oh that's cool", max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)

def test_valid_model(self):
prediction = co.chat("Yo what up?", model="medium", max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)

def test_invalid_model(self):
with self.assertRaises(cohere.CohereError):
Expand All @@ -37,28 +34,24 @@ def test_invalid_model(self):
def test_return_chatlog(self):
prediction = co.chat("Yo what up?", return_chatlog=True, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)
self.assertIsNotNone(prediction.chatlog)
self.assertGreaterEqual(len(prediction.chatlog), len(prediction.text))

def test_return_chatlog_false(self):
prediction = co.chat("Yo what up?", return_chatlog=False, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)

assert prediction.chatlog is None

def test_return_prompt(self):
prediction = co.chat("Yo what up?", return_prompt=True, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)
self.assertIsNotNone(prediction.prompt)
self.assertGreaterEqual(len(prediction.prompt), len(prediction.text))

def test_return_prompt_false(self):
prediction = co.chat("Yo what up?", return_prompt=False, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)
assert prediction.prompt is None

def test_preamble_override(self):
Expand All @@ -67,7 +60,6 @@ def test_preamble_override(self):
"Yo what up?", preamble_override=preamble, return_prompt=True, return_preamble=True, max_tokens=5
)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)
self.assertIn(preamble, prediction.prompt)
self.assertEqual(preamble, prediction.preamble)

Expand All @@ -82,7 +74,6 @@ def test_valid_temperatures(self):
for temperature in temperatures:
prediction = co.chat("Yo what up?", temperature=temperature, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)

def test_stream(self):
prediction = co.chat(
Expand All @@ -94,7 +85,6 @@ def test_stream(self):
self.assertIsInstance(prediction, cohere.responses.chat.StreamingChat)
self.assertIsInstance(prediction.texts, list)
self.assertEqual(len(prediction.texts), 0)
self.assertIsNone(prediction.conversation_id)
self.assertIsNone(prediction.response_id)
self.assertIsNone(prediction.finish_reason)

Expand All @@ -111,7 +101,6 @@ def test_stream(self):
expected_index += 1

self.assertEqual(prediction.texts, [expected_text])
self.assertIsNotNone(prediction.conversation_id)
self.assertIsNotNone(prediction.response_id)
self.assertIsNotNone(prediction.finish_reason)

Expand All @@ -127,15 +116,13 @@ def test_id(self):
def test_return_preamble(self):
prediction = co.chat("Yo what up?", return_preamble=True, return_prompt=True, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)
self.assertIsNotNone(prediction.preamble)
self.assertIsNotNone(prediction.prompt)
self.assertIn(prediction.preamble, prediction.prompt)

def test_return_preamble_false(self):
prediction = co.chat("Yo what up?", return_preamble=False, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)

assert prediction.preamble is None

Expand All @@ -151,7 +138,6 @@ def test_chat_history(self):
max_tokens=5,
)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)
self.assertIsNotNone(prediction.chatlog)
self.assertIn("User: Hey!", prediction.prompt)
self.assertIn("Chatbot: Hey! How can I help you?", prediction.prompt)
Expand Down Expand Up @@ -181,7 +167,6 @@ def test_token_count(self):
def test_p(self):
prediction = co.chat("Yo what up?", p=0.9, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)

def test_invalid_p(self):
with self.assertRaises(cohere.error.CohereError):
Expand All @@ -190,7 +175,6 @@ def test_invalid_p(self):
def test_k(self):
prediction = co.chat("Yo what up?", k=5, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)

def test_invalid_k(self):
with self.assertRaises(cohere.error.CohereError):
Expand All @@ -199,7 +183,6 @@ def test_invalid_k(self):
def test_logit_bias(self):
prediction = co.chat("Yo what up?", logit_bias={42: 10}, max_tokens=5)
self.assertIsInstance(prediction.text, str)
self.assertIsInstance(prediction.conversation_id, str)

def test_invalid_logit_bias(self):
invalid = [
Expand Down

0 comments on commit f71644d

Please sign in to comment.