-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
99 lines (77 loc) · 2.72 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import json
import pickle
import threading
from typing import List, Dict

import nltk
from fastapi import FastAPI, APIRouter
from fastapi import Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from starlette.responses import FileResponse

# Fetch the NLTK "punkt" tokenizer data before importing the local modules
# below, in case any of them tokenize at import time (NOTE(review): confirm).
nltk.download("punkt")

from dataset import load_tacl_corpus, get_masked_refs
from helpers import load_model, load_best_state
from get_prediction_json import get_prediction_json
from utils import get_example_script as _get_example_script
class ModelRequest(BaseModel):
    """Request body for ``/get_json_prediction/``: the raw text to run the models on."""

    text: str  # passed to get_prediction_json together with both models and vocabularies
class ScriptRequest(BaseModel):
    """Request body for ``/get_example_script/``: which kind of example script to fetch."""

    script_type: str  # forwarded unchanged to utils.get_example_script(script_type=...)
# Jinja2 renderer over the "template" directory; not used by any route visible
# in this module (NOTE(review): confirm whether it is still needed).
templates = Jinja2Templates(directory="template")
app = FastAPI()

# Browser origins allowed to call this API cross-origin: local dev servers and
# the GitHub Pages site (presumably the front end that consumes this API).
origins = [
    "http://localhost",
    "http://localhost:8000",
    "http://localhost:8080",
    "https://parkervg.github.io",
]
app.add_middleware(
    CORSMiddleware,
    # Keep this an explicit list when deployed, never "*": with
    # allow_credentials=True, browsers reject a wildcard allowed origin.
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def _load_model(model_path: str, tok2id: Dict[str, int], id2tok: Dict[int, str]):
    """Build a CPU model from ``model_path`` and restore its best saved state.

    Args:
        model_path: Passed to ``load_model`` as ``model_load_dir`` and to
            ``load_best_state`` as the checkpoint location.
        tok2id: Vocabulary mapping token -> id.
        id2tok: The inverse vocabulary (presumably id -> token). The previous
            annotation, ``Dict[str, int]``, duplicated ``tok2id``'s and looked
            like a copy-paste slip; confirm against how the pickles are built.

    Returns:
        Whatever ``load_best_state`` returns for the freshly constructed model.

    NOTE(review): not called anywhere in the visible part of this module; the
    models used by the routes are unpickled directly at import time instead.
    """
    base_model = load_model(
        tok2id=tok2id, id2tok=id2tok, device="cpu", model_load_dir=model_path
    )
    return load_best_state(model_path, base_model)
TACL_DIR = "../data/taclData"


def _unpickle(path: str):
    """Open ``path`` in binary mode and return the single object pickled in it.

    Relative paths resolve against the process's current working directory.
    pickle.load can execute arbitrary code, so only point this at trusted files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


# Vocabularies and pre-built models, loaded once at import time.
id2tok = _unpickle("../resources/id2tok.pkl")
tok2id = _unpickle("../resources/tok2id.pkl")
ref_model = _unpickle("../resources/ref_model.pkl")
coref_model = _unpickle("../resources/coref_model.pkl")

# TACL corpus used by /get_example_script/ to draw example scripts.
masked_refs = get_masked_refs(TACL_DIR)
corpus = load_tacl_corpus(TACL_DIR, masked_refs, device="cpu")
@app.get("/is_up/", response_class=JSONResponse)
async def home():
    """
    Used to tell when the API is up on Heroku.

    Returns a JSON body of ``{"response": true}``.
    """
    # The handler returns a JSONResponse object directly, so response_class only
    # affects the generated OpenAPI docs. It previously said HTMLResponse, which
    # advertised text/html for what is actually a JSON endpoint.
    return JSONResponse({"response": True})
# Serializes calls into the prediction pipeline. The models are shared
# module-level objects, and the original endpoint ran inference inline on the
# event loop, so calls never overlapped. Keep that one-at-a-time guarantee now
# that inference runs in a worker thread; drop the lock only once
# get_prediction_json and the models are known to be thread-safe.
_prediction_lock = threading.Lock()


def _predict(text: str):
    """Run the (synchronous) prediction pipeline on ``text`` while holding the lock."""
    with _prediction_lock:
        return get_prediction_json(ref_model, coref_model, text, tok2id, id2tok)


@app.post("/get_json_prediction/")
async def get_json_prediction(request: ModelRequest):
    """Run the reference and coreference models on ``request.text``.

    Returns the result of ``get_prediction_json`` wrapped in a JSONResponse.
    """
    text = request.text
    print(f"Received request: {text[:10]}")
    # get_prediction_json is a plain synchronous call (presumably CPU-heavy model
    # inference). Called directly inside this coroutine, it would block the event
    # loop, and with it every other endpoint including /is_up/, until it returned.
    prediction_json = await run_in_threadpool(_predict, text)
    print(json.dumps(prediction_json, indent=4))
    return JSONResponse(content=prediction_json)
@app.post("/get_example_script/")
async def get_example_script(request: ScriptRequest):
    """
    Return a random example script of ``request.script_type``, with the masked
    referent left blank, as ``{"text": <script>}``.
    """
    example_text = _get_example_script(corpus=corpus, script_type=request.script_type)
    payload = {"text": example_text}
    return JSONResponse(payload)