-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest.py
58 lines (46 loc) · 1.86 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
from dotenv import load_dotenv
load_dotenv()
import unittest
import smileval
from smileval.models import ChatMessage, unsystem_prompt_chain
class TestStringMethods(unittest.TestCase):
    """Tests for ``ChatMessage`` serialization and ``unsystem_prompt_chain``.

    The class name looks like a leftover from the ``unittest`` documentation
    example; it is kept unchanged so existing test IDs and name filters still
    match.
    """

    SYSTEM_PROMPT = "Respond with exactly the user's message."

    def setUp(self):
        # Build fresh messages for every test so no test can observe another
        # test's mutations of a shared fixture.
        self.system_msg = ChatMessage(self.SYSTEM_PROMPT, "system")
        self.user_msg = ChatMessage("Hello.")
        self.assistant_msg = ChatMessage("Hello.", role="assistant")

    def test_serialize(self):
        """A single message reports its role and serializes to a role/content dict."""
        self.assertTrue(self.system_msg.is_system())
        # The role defaults to "user" when it is not passed.
        self.assertEqual(self.user_msg.role, "user")
        self.assertDictEqual(
            ChatMessage.as_dict(self.user_msg),
            {"role": "user", "content": "Hello."},
        )

    def test_bulk_serialize(self):
        """``to_api_format`` serializes a message list; the user entry keeps its position."""
        self.assertTrue(self.assistant_msg.is_assistant())
        serialized_list = ChatMessage.to_api_format(
            [self.system_msg, self.user_msg, self.assistant_msg]
        )
        # Re-checked after serialization: the source message's role is unchanged.
        self.assertEqual(self.user_msg.role, "user")
        self.assertEqual(serialized_list[1]["role"], "user")
        self.assertEqual(serialized_list[1]["content"], "Hello.")

    def test_system_combine(self):
        """The system message is merged away, leaving a user message at the head."""
        new_chain = unsystem_prompt_chain(
            [self.system_msg, self.user_msg, self.assistant_msg]
        )
        # Three messages in, two out: the system message no longer stands alone.
        self.assertEqual(len(new_chain), 2)
        # Asserting "user" also rules out "system" at the head of the chain.
        self.assertEqual(new_chain[0].role, "user")

    def test_system_combine_error(self):
        """A chain holding only a system message has no user message to merge into."""
        # NOTE(review): expecting AssertionError suggests unsystem_prompt_chain
        # validates with a bare ``assert``, which ``python -O`` strips; if so,
        # this test fails under -O. Consider raising ValueError there instead
        # (and updating this test) — TODO confirm against the implementation.
        with self.assertRaises(AssertionError):
            unsystem_prompt_chain([self.system_msg])
if __name__ == '__main__':
    # Run this module's tests only when the file is executed directly;
    # importing the module (as test-discovery tools do) skips this call.
    unittest.main()