-
Notifications
You must be signed in to change notification settings - Fork 156
/
Copy pathws_service.py
157 lines (136 loc) · 6.72 KB
/
ws_service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
from datetime import datetime
import websockets
import asyncio
import os
import uuid
import json
import functools
import traceback
import sys
import logging
from multiprocessing import current_process, Process, Queue, queues
from common import MessageType, format_message, timestamp
import startup
# Connected clients, keyed by a per-connection uid -> websocket object.
user_dict = {}

# Optional "magic" key: a client that sends exactly this value as its API key
# gets the server-side default keys below. The defaults are only read from the
# environment when this switch is configured.
KEY_TO_USE_DEFAULT = os.getenv("KEY_TO_USE_DEFAULT")
if KEY_TO_USE_DEFAULT is None:
    DEFAULT_LLM_API_KEY = None
    DEFAULT_SERP_API_KEY = None
else:
    DEFAULT_LLM_API_KEY = os.getenv("DEFAULT_LLM_API_KEY")
    DEFAULT_SERP_API_KEY = os.getenv("DEFAULT_SERP_API_KEY")

logging.basicConfig(
    level=logging.WARNING,
    format='%(asctime)s | %(levelname)-8s | %(module)s:%(funcName)s:%(lineno)d - %(message)s',
)
logger = logging.getLogger(__name__)
def _pick_api_key(data, field, fallback, default_key, default_notice):
    """Choose the API key for one provider from a client's task payload.

    Args:
        data: The ``message["data"]`` dict sent by the client.
        field: Key in *data* that may hold a client-supplied API key.
        fallback: Key used when the client did not supply a usable one
            (the server-side key handed to handle_message).
        default_key: Server default substituted when the client sends
            KEY_TO_USE_DEFAULT; ``None`` disables the substitution.
        default_notice: Warning logged when the default key is substituted.

    Returns:
        The stripped client key if it is a string of at least 32 characters,
        or *default_key* when that string equals KEY_TO_USE_DEFAULT.
        Otherwise *fallback*.
    """
    candidate = data.get(field)
    if not isinstance(candidate, str):
        return fallback
    candidate = candidate.strip()
    # NOTE: the 32-character minimum is checked before the KEY_TO_USE_DEFAULT
    # comparison, so a shorter "use default" token is never recognised.
    if len(candidate) < 32:
        return fallback
    if KEY_TO_USE_DEFAULT is not None and \
            default_key is not None and \
            candidate == KEY_TO_USE_DEFAULT:
        # replace with default key
        logger.warning(default_notice)
        return default_key
    return candidate


async def handle_message(task_id=None, message=None, alg_msg_queue=None, proxy=None, llm_api_key=None, serpapi_key=None):
    """Validate a RunTask request and run ``startup.startup`` for it.

    Every outcome is reported to the client by putting a formatted RunTask
    message on *alg_msg_queue*:
      - "Invalid OpenAI key" / "Invalid SerpAPI key" / "Invalid task idea"
        when validation fails;
      - "finished" (with the task_id) after startup completes;
      - the exception text if startup raises (the traceback is logged).

    Args:
        task_id: Identifier of this task, forwarded to startup and echoed back.
        message: Decoded client message; ``message["data"]`` may carry
            "llm_api_key", "serpapi_key" and "idea". Missing or non-string
            fields are treated as absent rather than crashing the task process.
        alg_msg_queue: Queue read by send_msg_worker in the server process.
        proxy: Proxy forwarded to startup.
        llm_api_key / serpapi_key: Server-side keys used when the client does
            not provide its own.
    """
    data = message.get("data") if isinstance(message, dict) else None
    if not isinstance(data, dict):
        data = {}

    llm_api_key = _pick_api_key(data, "llm_api_key", llm_api_key, DEFAULT_LLM_API_KEY,
                                "Using default llm api key")
    serpapi_key = _pick_api_key(data, "serpapi_key", serpapi_key, DEFAULT_SERP_API_KEY,
                                "Using default serp api key")
    raw_idea = data.get("idea")
    # A missing/non-string idea becomes "" so it is reported as invalid below
    # instead of raising outside the try block (the client would never hear back).
    idea = raw_idea.strip() if isinstance(raw_idea, str) else ""

    if not llm_api_key:
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg="Invalid OpenAI key"))
        return
    if not serpapi_key:
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg="Invalid SerpAPI key"))
        return
    if len(idea) < 2:
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg="Invalid task idea"))
        return

    try:
        await startup.startup(idea=idea, task_id=task_id, llm_api_key=llm_api_key, serpapi_key=serpapi_key, proxy=proxy, alg_msg_queue=alg_msg_queue)
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, data={'task_id': task_id}, msg="finished"))
    except Exception as e:
        # Top-level boundary of the task process: report to the client, log the traceback.
        alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, msg=f"{e}"))
        logger.exception("Task %s failed", task_id)
def handle_message_wrapper(task_id=None, message=None, alg_msg_queue=None, proxy=None, llm_api_key=None, serpapi_key=None):
    """Synchronous entry point for a task process.

    It is used as the ``target`` of the per-task ``multiprocessing.Process``
    spawned by read_msg_worker. It logs the process name and then drives
    handle_message to completion in a fresh event loop.
    """
    proc_name = current_process().name
    logger.warning("New task:" + proc_name)
    task_coro = handle_message(task_id, message, alg_msg_queue, proxy, llm_api_key, serpapi_key)
    asyncio.run(task_coro)
def clear_queue(alg_msg_queue:Queue=None):
    """Discard every message currently waiting in *alg_msg_queue*.

    Does nothing when *alg_msg_queue* is None. The original guard tested the
    ``Queue`` class, which is always truthy, so ``None`` crashed here.

    Best-effort: the multiprocessing docs note a short delay before objects
    put by another process become visible, so items still in flight may
    arrive after this returns.
    """
    if alg_msg_queue is None:
        return
    try:
        while True:
            alg_msg_queue.get_nowait()
    except queues.Empty:
        pass
def _stop_task(process, alg_msg_queue):
    """Terminate *process* if it is still alive and drop its queued messages.

    Returns True if a live process was terminated, False otherwise.
    """
    if process is None or not process.is_alive():
        return False
    logger.warning("Interrupt task:" + process.name)
    process.terminate()
    clear_queue(alg_msg_queue=alg_msg_queue)
    return True


# read websocket messages
async def read_msg_worker(websocket=None, alg_msg_queue=None, proxy=None, llm_api_key=None, serpapi_key=None):
    """Consume client messages and manage the single task process of this connection.

    Handled actions:
      - Interrupt: if the running task's name equals ``data["task_id"]``, it is
        terminated, its pending output is discarded, and Interrupt plus
        "finished" messages are queued for the client.
      - RunTask: any running task is terminated, then a new daemon Process
        named with a fresh uuid runs handle_message_wrapper.
    Frames that are not a JSON object are logged and skipped instead of
    tearing down the connection.

    The running task process is terminated however this coroutine exits:
    normal end of stream, abnormal close, error, or cancellation.
    Before this change, only the normal end of stream cleaned it up.

    Raises:
        websockets.exceptions.ConnectionClosed: always, once the message stream
            ends normally. echo() catches this to leave its asyncio.gather.
    """
    process = None
    try:
        async for raw_message in websocket:
            try:
                message = json.loads(raw_message)
            except ValueError:  # JSONDecodeError and UnicodeDecodeError
                logger.warning("Ignoring malformed websocket message")
                continue
            if not isinstance(message, dict):
                logger.warning("Ignoring websocket message that is not a JSON object")
                continue

            action = message.get("action")
            if action == MessageType.Interrupt.value:
                # force interrupt a specific task
                data = message.get("data")
                task_id = data.get("task_id") if isinstance(data, dict) else None
                # Only acknowledge when a live matching task was actually stopped;
                # a task that ended on its own already reported its outcome via
                # the queue (see handle_message).
                if process is not None and process.name == task_id and _stop_task(process, alg_msg_queue):
                    process = None
                    alg_msg_queue.put_nowait(format_message(action=MessageType.Interrupt.value, data={'task_id': task_id}))
                    alg_msg_queue.put_nowait(format_message(action=MessageType.RunTask.value, data={'task_id': task_id}, msg="finished"))
            elif action == MessageType.RunTask.value:
                # auto interrupt previous task
                _stop_task(process, alg_msg_queue)
                task_id = str(uuid.uuid4())
                process = Process(target=handle_message_wrapper, args=(task_id, message, alg_msg_queue, proxy, llm_api_key, serpapi_key))
                process.daemon = True
                process.name = task_id
                process.start()
    finally:
        # auto terminate process, including on abnormal close / cancellation
        _stop_task(process, alg_msg_queue)
    # NOTE(review): the (0, "websocket closed") arguments match the pre-10.0
    # websockets signature (code, reason); newer versions expect Close frames.
    # Confirm against the pinned websockets version.
    raise websockets.exceptions.ConnectionClosed(0, "websocket closed")
# send: forward messages queued by the task process to the client
async def send_msg_worker(websocket=None, alg_msg_queue=None):
    """Relay every message from *alg_msg_queue* to *websocket*, forever.

    While the queue is empty, it is polled every 0.5 s. Each message is
    echoed to stdout before sending. The loop has no exit of its own: it
    ends only when ``websocket.send`` raises or the task is cancelled.
    """
    while True:
        if not alg_msg_queue.empty():
            outgoing = alg_msg_queue.get_nowait()
            print("=====Sending msg=====\n", outgoing)
            await websocket.send(outgoing)
            continue
        await asyncio.sleep(0.5)
async def echo(websocket, proxy=None, llm_api_key=None, serpapi_key=None):
    """Serve one websocket connection.

    The connection is registered in ``user_dict`` under a fresh uid. Then
    read_msg_worker and send_msg_worker run concurrently over a per-connection
    multiprocessing Queue, and the uid is unregistered when they stop.

    ``asyncio.gather`` does not cancel the remaining awaitable when one of them
    raises. The previous ``asyncio.current_task().cancel()`` cancelled this
    handler rather than the workers, so the endless send_msg_worker (and its
    queue) leaked after every disconnect. The sender is now cancelled
    explicitly.
    """
    # auto register
    uid = datetime.strftime(datetime.now(), '%Y%m%d%H%M%S.%f')+'_'+str(uuid.uuid4())
    logger.warning(f"New user registered, uid: {uid}")
    if uid not in user_dict:
        user_dict[uid] = websocket
    else:
        logger.warning(f"Duplicate user, uid: {uid}")
    reader = sender = None
    # message handling
    try:
        alg_msg_queue = Queue()
        reader = asyncio.ensure_future(
            read_msg_worker(websocket=websocket, alg_msg_queue=alg_msg_queue, proxy=proxy, llm_api_key=llm_api_key, serpapi_key=serpapi_key))
        sender = asyncio.ensure_future(
            send_msg_worker(websocket=websocket, alg_msg_queue=alg_msg_queue))
        await asyncio.gather(reader, sender)
    except websockets.exceptions.ConnectionClosed:
        logger.warning("Websocket closed: remote endpoint going away")
    finally:
        if sender is not None:
            # send_msg_worker never returns on its own; stop it explicitly.
            sender.cancel()
        if reader is not None and not reader.done():
            # The reader leaves its loop once the connection closes and terminates
            # the task process on the way out, so it is not cancelled here. Its
            # final exception is retrieved so asyncio does not report it as unhandled.
            reader.add_done_callback(lambda t: t.cancelled() or t.exception())
        # auto unregister
        logger.warning(f"Auto unregister, uid: {uid}")
        user_dict.pop(uid, None)
async def run_service(host: str = "localhost", port: int=9000, proxy: str=None, llm_api_key:str=None, serpapi_key:str=None):
    """Start the websocket server on host:port and serve until cancelled.

    Each connection is handled by ``echo``, with *proxy*, *llm_api_key* and
    *serpapi_key* bound as server-wide defaults.
    """
    connection_handler = functools.partial(
        echo,
        proxy=proxy,
        llm_api_key=llm_api_key,
        serpapi_key=serpapi_key,
    )
    proxy_suffix = f'[proxy={proxy}]' if proxy else ''
    async with websockets.serve(connection_handler, host, port):
        logger.warning(f"Websocket server started: {host}:{port} {proxy_suffix}")
        # A bare Future never resolves, so the server runs until this task is cancelled.
        await asyncio.Future()