-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
122 lines (95 loc) · 3.59 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import asyncio
import gradio as gr
import httpx
import os
from pydantic import BaseModel
from typing import List, Optional, Any
# Bearer token for the chat-completions endpoint, read once at import time.
# It is None when the OPENAI_API_KEY environment variable is not set.
API_KEY = os.environ.get("OPENAI_API_KEY")
class BasicInfo(BaseModel):
    """Progress counters and flags, presumably for one player's game session.

    NOTE(review): field meanings below are inferred from their names only;
    confirm against the code that reads and writes them.
    """
    # Index of the topic currently being played; defaults to 0.
    current_topic_index: int = 0
    # Presumably: whether the current topic has been passed.
    is_passed: bool = False
    # Presumably: whether the whole game has finished.
    is_finished: bool = False
    # Presumably: number of attempts made so far (per topic? verify).
    attempt_times: int = 0
    # Presumably: number of topics passed so far.
    passed_count: int = 0
class Message(BaseModel):
    """A single chat message: who said it (``role``) and what was said (``content``).

    NOTE(review): the chat helpers in this module build and send plain
    ``{"role": ..., "content": ...}`` dicts, not instances of this model.
    """
    # Speaker role; this module uses "user" and "assistant".
    role: str
    # Message text.
    content: str
def update_current_index(index: int, total: int = 5) -> str:
    """Render the "question N of M" heading as centred HTML.

    Args:
        index: Question number to display; inserted verbatim.
        total: Total number of questions. Defaults to 5, the value that was
            previously hard-coded, so existing callers are unaffected.

    Returns:
        An ``<h2>`` heading such as ``第 1 题 / 共 5 题``
        ("question 1 / 5 in total").
    """
    return f"<h2><center>第 {index} 题 / 共 {total} 题</center></h2>"
def update_current_problem(problem: str) -> str:
    """Return the problem text exactly as given.

    A pass-through kept alongside the other ``update_current_*`` renderers
    (presumably so it can be wired as a UI update callback).
    """
    displayed_problem = problem
    return displayed_problem
def update_current_rules(rules: List) -> str:
    """Render the topic's rules as an HTML fragment.

    The fragment is a "要求" ("requirements") heading followed by one
    paragraph in which the rules are separated by ``<br>`` tags.

    Args:
        rules: Rule strings, joined in order.

    Returns:
        The HTML fragment, with a leading and a trailing newline.
    """
    rule_lines = "<br>".join(rules)
    heading = "<h3><font size=4.75rem>要求</h3>"
    body = f"<p><font size=3.5rem>{rule_lines}</p>"
    return "\n" + heading + "\n" + body + "\n"
def varify_input(
topic_limits: dict[str, int | list[str] | str] | None,
input_: str) -> bool:
result = True
if input_ == "":
result = False
if topic_limits['words_count']:
result = (len(input_) < topic_limits['words_count'])
if topic_limits['ban_words']:
for ban_word in topic_limits['ban_words']:
if ban_word in input_:
result = False
break
if topic_limits['contain_words']:
for contain_word in topic_limits['contain_words']:
if contain_word not in input_:
result = False
break
return result
async def init_chat(
        topic_limits: dict[str, int | list[str] | str] | None,
        history: Any) -> tuple[list[tuple], Any]:
    """Seed a new conversation with the topic's premise, if it defines one.

    When ``topic_limits`` has a truthy ``'premise'``, that text is sent as the
    first user turn via ``get_response``. Otherwise nothing is sent and
    ``history`` is returned as given. A ``None`` ``topic_limits`` or a missing
    ``'premise'`` key is treated as "no premise" instead of raising.

    Args:
        topic_limits: Topic configuration. May be ``None`` or lack ``'premise'``.
        history: Chat history passed through to ``get_response``, which appends
            to it in place.

    Returns:
        ``(messages, history)``. ``messages`` is the list of
        ``(user, assistant)`` pairs from ``get_response``, or empty when no
        premise was sent.
    """
    messages = []
    premise = topic_limits.get('premise') if topic_limits else None
    if premise:
        _, messages, history = await get_response(premise, history)
    return messages, history
async def get_response(
        input_: str,
        history: Any) -> tuple[str, list[tuple], Any]:
    """Send one user turn to the model and return the updated chat state.

    The user turn is appended to ``history`` in place, and the model's reply
    is appended after it. If the request fails (``chat_interface`` returns
    ``None``), the user turn is rolled back. This keeps a null assistant
    message out of ``history``; otherwise it would be re-sent, and rejected,
    on every later request. The input text is then returned so the caller
    can offer it for resending.

    Args:
        input_: Text of the user's turn.
        history: List of ``{"role": ..., "content": ...}`` dicts, mutated in
            place. Pairing assumes turns strictly alternate, starting with a
            user turn.

    Returns:
        ``(remaining_input, messages, history)``:

        * ``remaining_input`` is ``""`` on success and ``input_`` on failure.
        * ``messages`` lists ``(user_content, assistant_content)`` pairs built
          from ``history``.
        * ``history`` is the same list object that was passed in.
    """
    history.append({"role": "user", "content": input_})
    response = await chat_interface(history)
    if response is None:
        # chat_interface has already shown a warning. Undo the unanswered turn.
        history.pop()
        remaining_input = input_
    else:
        history.append({"role": "assistant", "content": response})
        remaining_input = ""
    messages = [
        (user_turn["content"], bot_turn["content"])
        for user_turn, bot_turn in zip(history[0::2], history[1::2])
    ]
    return remaining_input, messages, history
async def chat_interface(
        messages: List[dict[str, Any]],
        number_retries: int = 5,
        *,
        model: str = "gpt-4",
        url: str = "https://chat-api.cx0.cc/v1/chat/completions",
        timeout_seconds: float = 100) -> Optional[str]:
    """POST ``messages`` to the chat-completions endpoint and return the reply.

    Failures are reported to the UI via ``gr.Warning`` rather than raised.

    * Timeouts and other request errors are retried.
    * Transient statuses (429 and 5xx) are retried.
    * A 200 response whose body cannot be parsed is retried.
    * Any other status stops immediately, since a bad key or payload will not
      succeed on retry.

    Args:
        messages: Conversation as JSON-serialisable
            ``{"role": ..., "content": ...}`` dicts. They are sent verbatim as
            the request's ``messages`` field, so pydantic models will not
            serialise here.
        number_retries: Maximum number of attempts. At least one attempt is
            always made.
        model: Model name placed in the request body.
        url: Chat-completions endpoint to POST to.
        timeout_seconds: Per-request timeout passed to httpx.

    Returns:
        ``choices[0].message.content`` from the first successful response, or
        ``None`` if no attempt succeeded.
    """
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {API_KEY}",
    }
    # The original loop always made one attempt, even with number_retries <= 0.
    attempts = max(1, number_retries)
    async with httpx.AsyncClient(headers=headers) as aio_client:
        for _ in range(attempts):
            try:
                resp = await aio_client.post(
                    url=url,
                    json={"model": model, "messages": messages},
                    timeout=timeout_seconds,
                )
            except httpx.TimeoutException:
                # httpx raises its own timeout types, never asyncio.TimeoutError.
                # Catching them here is what makes this message reachable.
                gr.Warning("Timeout!")
                continue
            except Exception as e:
                # Best effort: surface the error and retry.
                gr.Warning(f"{e}")
                continue

            gr.Info(f"Status Code : {resp.status_code}")
            if resp.status_code == 200:
                try:
                    return resp.json()["choices"][0]["message"]["content"]
                except Exception as e:
                    # Malformed success body: report it and retry.
                    gr.Warning(f"{e}")
                    continue

            gr.Warning(f"{resp.content}")
            if resp.status_code != 429 and resp.status_code < 500:
                # Client errors (bad key, bad payload, ...) will not succeed on retry.
                break
    return None