-
Notifications
You must be signed in to change notification settings - Fork 358
/
tests.py
68 lines (62 loc) · 2.04 KB
/
tests.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
from utils import (
generate_together,
generate_openai,
inject_references_to_messages,
generate_with_references,
)
# Model passed to every Together-backed check below.
_TOGETHER_MODEL = "meta-llama/Llama-3-8b-chat-hf"

# Reference answers supplied in checks #3 and #4. Kept as a tuple and copied
# into a new list per call so no call can observe another call's mutations.
_REFERENCES = (
    "Hello! How can I help you today?",
    "Hello! How can I assist you today?",
)

# Checks #3 and #4 compare against the very same reply text.
_EXPECTED_WITH_REFERENCES = (
    "Hello! It seems like you're looking for assistance with something. I'm here to help! Could you please provide more context or clarify what's on your mind? I'll do my best to offer a helpful and accurate response."
)


def _hello_messages():
    """Return a fresh single-turn conversation holding the user greeting."""
    return [{"role": "user", "content": "hello!"}]


def _assert_reply(actual, expected):
    """Assert that ``actual`` equals ``expected`` after stripping whitespace.

    Args:
        actual: Raw text returned by a generation helper.
        expected: Exact reply text the check requires.

    Raises:
        AssertionError: If the stripped reply differs; the message carries
            both strings so a failed run shows what was actually returned.
    """
    got = actual.strip()
    assert got == expected, f"expected {expected!r}, got {got!r}"


if __name__ == "__main__":
    # NOTE(review): every check below compares against an exact reply string
    # with temperature=0; presumably a provider-side model change would break
    # them -- confirm before treating a failure as a regression in utils.

    # 1: plain completion through generate_together.
    output = generate_together(
        _TOGETHER_MODEL,
        _hello_messages(),
        temperature=0,
    )
    _assert_reply(
        output,
        "Hello! It's nice to meet you. Is there something I can help you with, or would you like to chat?",
    )
    print("#1 pass")

    # 2: plain completion through generate_openai.
    output = generate_openai(
        "gpt-3.5-turbo",
        _hello_messages(),
        temperature=0,
    )
    _assert_reply(output, "Hello! How can I assist you today?")
    print("#2 pass")

    # 3: inject the references by hand, check the resulting message list,
    # then generate from it.
    messages = inject_references_to_messages(
        _hello_messages(),
        list(_REFERENCES),
    )
    assert len(messages) == 2, f"expected 2 messages, got {len(messages)}"
    first_role = messages[0]["role"]
    assert first_role == "system", f"expected a leading system message, got {first_role!r}"
    output = generate_together(
        _TOGETHER_MODEL,
        messages,
        temperature=0,
    )
    _assert_reply(output, _EXPECTED_WITH_REFERENCES)
    print("#3 pass")

    # 4: same inputs as check #3, but through generate_with_references;
    # the expected reply is identical to check #3's.
    output = generate_with_references(
        _TOGETHER_MODEL,
        _hello_messages(),
        references=list(_REFERENCES),
        temperature=0,
    )
    _assert_reply(output, _EXPECTED_WITH_REFERENCES)
    print("#4 pass")