-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
126 lines (96 loc) · 4.32 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import asyncio
import json
import os
import random
import sys
from os.path import exists
from pathlib import Path
from time import sleep

from fastapi import FastAPI, UploadFile, File
from starlette.responses import PlainTextResponse

from data.ExtractionData import ExtractionData
from data.LabeledData import LabeledData
from data.Options import Options
from data.PredictionData import PredictionData
from data.SegmentBox import SegmentBox
from data.Suggestion import Suggestion
# Application object; every route in this module is registered on it.
app = FastAPI()

# JSON scratch files, given as relative paths (resolved against the working directory).
# data.json: list of prediction payloads appended by POST /prediction_data,
# consumed and deleted by GET /get_suggestions.
data_path = Path("data.json")
# params.json: optional settings ("options", "multi_value") read and deleted by
# GET /get_suggestions; nothing in this module writes it.
params_path = Path("params.json")
# options.json: option list written by POST /options, used by GET /get_suggestions
# when params.json supplies no options.
options_path = Path("options.json")
@app.get("/info")
async def info():
    """Return the version string of the Python interpreter running this service."""
    interpreter_version = sys.version
    return interpreter_version
@app.post("/async_extraction/{tenant}")
async def async_extraction(tenant, file: UploadFile = File(...)):
    """Accept an uploaded file for ``tenant`` and reply with a fixed acknowledgement.

    The upload is neither read nor stored.
    """
    acknowledgement = "task registered"
    return acknowledgement
@app.post("/set_paragraphs")
async def set_paragraphs(extraction_data: ExtractionData):
    """Accept an extraction payload and reply with a fixed acknowledgement.

    The payload is validated by the request model but otherwise ignored.
    """
    acknowledgement = "paragraphs saved"
    return acknowledgement
@app.get("/get_paragraphs/{tenant}/{pdf_file_name}")
async def get_paragraphs(tenant: str, pdf_file_name: str):
    """Return a placeholder extraction result for the requested file as a JSON string.

    The result echoes ``tenant`` and ``pdf_file_name`` and carries a single
    default-constructed segment box with zero page dimensions.
    """
    print("get_paragraphs", tenant, pdf_file_name)
    placeholder_fields = {
        "tenant": tenant,
        "file_name": pdf_file_name,
        "paragraphs": [SegmentBox()],
        "page_height": 0,
        "page_width": 0,
    }
    return ExtractionData(**placeholder_fields).model_dump_json()
@app.get("/get_xml/{tenant}/{pdf_file_name}", response_class=PlainTextResponse)
async def get_xml():
    """Return the contents of ``test.xml`` (relative to the working directory) as plain text.

    The tenant and file name in the URL are not used: every request receives the same file.
    """
    return Path("test.xml").read_text()
@app.post("/xml_to_train/{tenant}/{extractor_id}")
async def to_train_xml_file(tenant, extractor_id, file: UploadFile = File(...)):
    """Log that a training XML upload arrived and acknowledge it; the file is not stored."""
    log_fields = ("received file to train", tenant, extractor_id)
    print(*log_fields)
    return "xml_to_train saved"
@app.post("/xml_to_predict/{tenant}/{extractor_id}")
async def to_predict_xml_file(tenant, extractor_id, file: UploadFile = File(...)):
    """Log that a prediction XML upload arrived and acknowledge it; the file is not stored."""
    log_fields = ("received file to predict", tenant, extractor_id)
    print(*log_fields)
    # NOTE(review): the acknowledgement says "xml_to_train" on the predict route; it looks
    # like a copy-paste slip, but it is kept as-is in case clients compare the body — confirm.
    return "xml_to_train saved"
@app.post("/labeled_data")
async def labeled_data_post(labeled_data: LabeledData):
    """Accept a labeled-data payload and reply with a fixed acknowledgement.

    The payload is validated by the request model but otherwise ignored.
    """
    acknowledgement = "labeled data saved"
    return acknowledgement
@app.post("/prediction_data")
async def prediction_data_post(prediction_data: PredictionData):
    """Append the posted prediction payload to the JSON list stored in ``data.json``.

    The file is created on the first call; later calls load the existing list,
    append to it and write the whole list back.
    """
    if exists(data_path):
        stored_predictions = json.loads(data_path.read_text())
    else:
        stored_predictions = []

    stored_predictions.append(prediction_data.model_dump())
    serialised = json.dumps(stored_predictions)
    data_path.write_text(serialised)
    return "prediction data saved"
@app.get("/get_suggestions/{tenant}/{extractor_id}")
async def get_suggestions(tenant: str, extractor_id: str):
    """Build one mock suggestion per stored prediction and return them as a JSON string.

    Predictions are read from ``data.json``. Candidate option values come from the
    ``"options"`` entry of ``params.json`` or, when that is missing or empty, from
    ``options.json``. Each suggestion gets a random sample of those options (exactly one
    unless ``"multi_value"`` is set in ``params.json``); with no options available the
    suggestion text falls back to ``"2023"`` and ``values`` is empty.

    After building the list, ``data.json`` and ``params.json`` are deleted, and the
    response is delayed by five seconds.

    Args:
        tenant: Tenant identifier copied into every suggestion.
        extractor_id: Extractor identifier copied into every suggestion as ``id``.

    Returns:
        A JSON-encoded string holding the list of suggestion dictionaries.
    """
    predictions_data = json.loads(data_path.read_text()) if data_path.exists() else []
    params = json.loads(params_path.read_text()) if params_path.exists() else {}

    # Options sent with the parameters win; the options file is only a fallback.
    all_values = params.get("options") or []
    if not all_values and options_path.exists():
        all_values = json.loads(options_path.read_text())
    multi_value = params.get("multi_value", False)

    suggestions_list = []
    for prediction_data in predictions_data:
        # random.randint(1, 0) raises ValueError, so only draw a count when there are options.
        values_count = random.randint(1, len(all_values)) if multi_value and all_values else 1
        values = random.sample(all_values, k=values_count) if all_values else []
        text = " ".join(option["label"] for option in values) if values else "2023"
        suggestions_list.append(
            Suggestion(
                tenant=tenant,
                id=extractor_id,
                xml_file_name=prediction_data["xml_file_name"],
                entity_name=prediction_data["entity_name"],
                text=text,
                values=values,
                segment_text=text,
                page_number=1,
                segments_boxes=[SegmentBox(left=0, top=0, width=250, height=250, page_number=1)],
            ).model_dump()
        )

    # Stored predictions and parameters are consumed by this call.
    if data_path.exists():
        os.remove(data_path)
    if params_path.exists():
        os.remove(params_path)

    # Keep the artificial delay, but yield to the event loop instead of blocking
    # every other request for its duration.
    await asyncio.sleep(5)
    return json.dumps(suggestions_list)
@app.post("/options")
def save_options(options: Options):
    """Overwrite ``options.json`` with the posted options and return ``True``."""
    serialisable_options = []
    for option in options.options:
        serialisable_options.append(option.model_dump())

    options_path.write_text(json.dumps(serialisable_options))
    return True