-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
94 lines (72 loc) · 2.51 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import asyncio
import inspect
import threading
from time import sleep
from uuid import uuid4

from fastapi import FastAPI, WebSocket, BackgroundTasks
from fastapi.middleware.cors import CORSMiddleware

from utils.promptconfig import PromptConfig
from utils.template_engine import TemplateEngine
from utils.experiments import Experiments
from utils.analyser import ConfigurationAnalyser
from utils.results import Results
from utils.db_connector import DBConnector
# Browser origins allowed to call this API cross-origin (with credentials).
origins = [
    "http://localhost",
    "http://localhost:3000",
]

app = FastAPI()

# CORS: only the origins above, but any method and any request header.
cors_options = {
    "allow_origins": origins,
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_options)

# In-process registry: experiment id -> task object whose is_alive() the
# /experiment/status websocket polls. Contents are lost when the process exits.
tasks = {}

# Single shared database connector, created once at import time.
db_instance = DBConnector()
@app.get("/")
async def root():
    """Return a fixed greeting; useful as a basic reachability check."""
    greeting = "Hello, World!"
    return {"message": greeting}
@app.post("/experiment/init")
async def init_experiment(data: dict, background_tasks: BackgroundTasks):
    """Validate an experiment configuration and start it running.

    The payload is checked with ``ConfigurationAnalyser.validate_data()``; if
    validation fails, ``{"error": <message>}`` is returned and nothing starts.
    Otherwise a fresh UUID is allocated and ``Experiments(project_id, data).init``
    is started on a worker thread. That thread is registered in ``tasks`` so
    the ``/experiment/status/{exp_id}`` websocket can poll ``is_alive()`` on it.

    ``background_tasks`` stays in the signature for compatibility but is not
    used: a Starlette background task exposes no handle whose progress can be
    polled, which previously left ``tasks`` permanently empty.

    Returns:
        ``{"message": "Experiment initialized!", "exp_id": <uuid str>}`` on
        success, or ``{"error": <validation message>}`` on invalid input.
    """
    analyser = ConfigurationAnalyser(data)
    validation_message, validation_result = analyser.validate_data()
    if not validation_result:
        return {"error": validation_message}

    # Generate a unique id for the project.
    project_id = str(uuid4())
    exp = Experiments(project_id, data)

    def _run_init() -> None:
        """Run ``exp.init`` to completion on the worker thread."""
        result = exp.init()
        # BackgroundTasks accepted both sync and async callables; keep that:
        # if init is a coroutine function, drive it on this thread's own loop.
        if inspect.iscoroutine(result):
            asyncio.run(result)

    # Non-daemon so an in-flight init is not killed abruptly at interpreter exit.
    worker = threading.Thread(target=_run_init, name=f"experiment-{project_id}")
    worker.start()
    # No await between start() and this assignment, so no other request can
    # observe the experiment half-registered.
    tasks[project_id] = worker
    return {"message": "Experiment initialized!", "exp_id": project_id}
@app.websocket("/experiment/status/{exp_id}")
async def experiment_status(websocket: WebSocket, exp_id: str):
    """Stream an experiment's progress over a websocket.

    Accepts the connection, then looks up ``exp_id`` in ``tasks``:

    * unknown id -> sends ``{"error": "Experiment not found"}``;
    * known id   -> sends ``{"status": "running"}`` about once per second
      while the task's ``is_alive()`` is true, then ``{"status": "completed"}``.

    The socket is closed explicitly before returning.
    """
    # accept/send_json/close are coroutines: without await they never run.
    await websocket.accept()

    task = tasks.get(exp_id)
    if task is None:
        await websocket.send_json({"error": "Experiment not found"})
        await websocket.close()
        return

    while task.is_alive():
        await websocket.send_json({"status": "running"})
        # asyncio.sleep yields to the event loop; time.sleep would stall every
        # other request served by this worker.
        await asyncio.sleep(1)

    # The loop only exits once is_alive() is False, so the former blocking
    # task.join() was redundant and has been dropped.
    await websocket.send_json({"status": "completed"})
    await websocket.close()
@app.get("/experiment/results/{exp_id}")
async def experiment_results(exp_id: str):
    """Return the stored results for ``exp_id``.

    If ``db_instance.is_project_exists(exp_id)`` is falsy, a
    ``{"error": "Experiment not found"}`` payload is returned instead.
    """
    if db_instance.is_project_exists(exp_id):
        return Results(exp_id).get_results()
    return {"error": "Experiment not found"}
@app.get("/experiment/listall")
async def experiment_listall():
    """Return whatever ``db_instance.get_projects()`` yields.

    Presumably the list of all stored projects; verify against DBConnector.
    """
    projects = db_instance.get_projects()
    return projects
@app.get("/test/promptTemplate/{exp_id}")
async def test_prompt_template(exp_id: str):
    """Debug endpoint: build a ``PromptConfig`` from ``app.state`` and return it.

    Reads ``app.state.config`` and ``app.state.dataset``, stores the resulting
    ``PromptConfig`` on ``app.state.prompt``, and returns it as ``{"prompt": ...}``.

    Nothing in this module assigns ``app.state.config`` or
    ``app.state.dataset``. When either is missing, an ``{"error": ...}``
    payload is returned rather than crashing with AttributeError (HTTP 500).

    NOTE(review): ``exp_id`` is currently unused. FastAPI must also be able to
    JSON-encode ``PromptConfig`` for the success response — confirm.
    """
    missing = [name for name in ("config", "dataset") if not hasattr(app.state, name)]
    if missing:
        return {"error": f"Prompt test state not initialised: missing {', '.join(missing)}"}
    app.state.prompt = PromptConfig(app.state.config, app.state.dataset)
    return {"prompt": app.state.prompt}
@app.get("/test/promptTokens/{exp_id}")
async def test_prompt_tokens(exp_id: str):
    """Debug endpoint: return ``app.state.tokens`` as ``{"tokens": ...}``.

    Nothing in this module assigns ``app.state.tokens``. When it is missing,
    an ``{"error": ...}`` payload is returned rather than crashing with
    AttributeError (HTTP 500). The previous stray debug ``print`` to stdout
    has been removed.

    NOTE(review): ``exp_id`` is currently unused.
    """
    if not hasattr(app.state, "tokens"):
        return {"error": "Prompt test state not initialised: missing tokens"}
    return {"tokens": app.state.tokens}
if __name__ == "__main__":
    import uvicorn

    # Development entry point: serve "app:app" on every interface, port 8000,
    # with auto-reload enabled.
    server_options = {"host": "0.0.0.0", "port": 8000, "reload": True}
    uvicorn.run("app:app", **server_options)