-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathserver.py
84 lines (70 loc) · 2.31 KB
/
server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# AIM OF THE FILE
"""
Expose the backend operations (vector-database creation plus earnings-call and SEC-filing chat) as FastAPI HTTP endpoints.
"""
import os

import openai
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException

from src.chat_earnings_call import get_openai_answer_earnings_call
from src.chat_sec import get_openai_answer_sec
from src.queryDatabase import query_database_earnings_call, query_database_sec
from src.vectorDatabase import create_database
# Load variables from a local .env file when one is present. Catch only
# I/O and decoding failures: a bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and genuine programming errors.
try:
    load_dotenv()
except (OSError, UnicodeDecodeError):
    # Best-effort: the variables may still be supplied by the real environment.
    pass

# The OpenAI key is mandatory; fail fast at import time with an actionable
# message (same exception type as a plain os.environ lookup).
try:
    openai.api_key = os.environ["OPENAI_API_KEY"]
except KeyError as err:
    raise KeyError(
        "OPENAI_API_KEY is not set; export it or add it to a .env file"
    ) from err

app = FastAPI()
@app.get("/data/{ticker}/{year}")
async def ticker_year(ticker: str, year: int):
    """Build the database for ``ticker``/``year`` and publish it module-wide.

    Every value returned by ``create_database`` is bound to a module-level
    global so that the ``/Earnings`` and ``/SEC`` endpoints can read it.
    Each call overwrites what a previous call stored, and the state is
    shared by every client of this process. The endpoint returns ``None``.

    NOTE(review): ``create_database`` is called synchronously inside an
    ``async def``; if it is slow it blocks the event loop while it runs.
    Consider a plain ``def`` endpoint; confirm.
    """
    global qdrant_client, encoder
    global speakers_list_1, speakers_list_2, speakers_list_3, speakers_list_4
    global sec_form_names, earnings_call_quarter_vals

    # Unpack straight into the shared module-level names; nothing is bound
    # if the returned sequence does not have exactly eight items.
    (
        qdrant_client,
        encoder,
        speakers_list_1,
        speakers_list_2,
        speakers_list_3,
        speakers_list_4,
        sec_form_names,
        earnings_call_quarter_vals,
    ) = create_database(ticker=ticker, year=year)
@app.get("/Earnings/{question}/{quarter}")
async def earnings_chat(question: str, quarter: str):
    """Answer ``question`` from the earnings-call data of ``quarter``.

    Relies on the module-level state (Qdrant client, encoder and one
    speakers list per quarter) that the ``/data/{ticker}/{year}`` endpoint
    assigns.

    Args:
        question: Free-text question taken from the URL path.
        quarter: One of ``"Q1"``, ``"Q2"``, ``"Q3"`` or ``"Q4"``.

    Returns:
        A ``(answer, relevant_text)`` pair: the result of
        ``get_openai_answer_earnings_call`` and the retrieved text it was
        given.

    Raises:
        HTTPException: 409 if no database has been loaded yet; 400 if
            ``quarter`` is not one of ``Q1``-``Q4``.
    """
    # The shared globals only exist after /data/{ticker}/{year} has run;
    # without this guard the lookups below fail with NameError (HTTP 500).
    if "qdrant_client" not in globals():
        raise HTTPException(
            status_code=409,
            detail="No database loaded; call /data/{ticker}/{year} first.",
        )

    speakers_by_quarter = {
        "Q1": speakers_list_1,
        "Q2": speakers_list_2,
        "Q3": speakers_list_3,
        "Q4": speakers_list_4,
    }
    # Reject unknown quarters explicitly; otherwise there is no speakers
    # list to pass on and the request would crash with a 500.
    if quarter not in speakers_by_quarter:
        raise HTTPException(
            status_code=400,
            detail=(
                f"Unknown quarter {quarter!r}; "
                f"expected one of {sorted(speakers_by_quarter)}."
            ),
        )

    relevant_text = query_database_earnings_call(
        question, quarter, qdrant_client, encoder, speakers_by_quarter[quarter]
    )
    res = get_openai_answer_earnings_call(question, relevant_text)
    return res, relevant_text
@app.get("/SEC/{question}/{doc_name}")
async def sec_chat(question: str, doc_name: str):
    """Answer ``question`` from the SEC filing named ``doc_name``.

    Relies on the module-level state (Qdrant client, encoder and the list
    of available SEC form names) that the ``/data/{ticker}/{year}``
    endpoint assigns.

    Args:
        question: Free-text question taken from the URL path.
        doc_name: Name of the filing; must be in ``sec_form_names``.

    Returns:
        A ``(answer, relevant_text)`` pair: the result of
        ``get_openai_answer_sec`` and the retrieved text it was given.

    Raises:
        HTTPException: 409 if no database has been loaded yet; 400 if
            ``doc_name`` is not one of the loaded form names.
    """
    # The shared globals only exist after /data/{ticker}/{year} has run;
    # without this guard the lookups below fail with NameError (HTTP 500).
    if "qdrant_client" not in globals():
        raise HTTPException(
            status_code=409,
            detail="No database loaded; call /data/{ticker}/{year} first.",
        )
    # Use an explicit check rather than `assert`: asserts are removed under
    # `python -O` and, when present, surface as an HTTP 500.
    if doc_name not in sec_form_names:
        raise HTTPException(
            status_code=400,
            detail=f"The document name should be in the list {sec_form_names}",
        )

    relevant_text = query_database_sec(question, qdrant_client, encoder, doc_name)
    res = get_openai_answer_sec(question, relevant_text)
    return res, relevant_text