-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathapp.py
71 lines (52 loc) · 1.67 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from train import Model
import torch
from flask import Flask
from flask_compress import Compress
from flask.helpers import make_response
import numpy as np
import sys
import base64
# --- Load the training data and the trained model once, at import time. ---
# The dataset directory is 'vh.<name>', where <name> is the first CLI argument.
if len(sys.argv) < 2:
    sys.exit('usage: python %s <dataset-name>' % sys.argv[0])
dataset_name = 'vh.' + sys.argv[1]
# Each entry is used below as data[0] (the model input, first dim = n_sensors)
# and data[1] (a vector whose length is n_chunks).
# map_location='cpu': the model below is built and run on the CPU, so any CUDA
# tensors stored in the file would otherwise break model(sensors) later.
train_data = torch.load(dataset_name + '/train.pth', map_location='cpu')
# (n_chunks, n_frames) distances: one row per chunk, one column per sample.
vecs = np.array([data[1].tolist() for data in train_data]).transpose()
# instantiate model
n_sensors = train_data[0][0].shape[0]
n_chunks = train_data[0][1].shape[0]
model = Model(n_sensors, n_chunks)
# load model weights (mapped to CPU for the same reason as train_data)
# NOTE(review): the checkpoint path is literally '<dataset>/.pth' with no base
# name -- looks like a truncated filename; confirm against the training script.
ckpt = torch.load(dataset_name + '/.pth', map_location='cpu')
model.load_state_dict(ckpt)
model.eval()  # inference mode (affects dropout / batch-norm layers, if any)
def get_frame_ids(sensors):
    """Predict, for every chunk, the index of the closest stored frame.

    Runs the module-level ``model`` on ``sensors``. Each row of the
    module-level ``vecs`` is then paired with the model's corresponding
    prediction. For each pair, the column (frame) whose stored value has the
    smallest absolute difference from the prediction is picked.

    Args:
        sensors: model input; in this file it is ``train_data[i][0]``.

    Returns:
        list of frame indices (``np.argmin`` results), one per
        (``vecs`` row, prediction) pair -- presumably one per chunk, assuming
        the model outputs ``n_chunks`` values.
    """
    # Inference only: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        pred_vecs = model(sensors).detach().numpy()
    return [np.argmin(np.abs(vec - pred_vec))
            for vec, pred_vec in zip(vecs, pred_vecs)]
# Flask application object, with flask_compress applied to its responses.
server = Flask(__name__)
# NOTE(review): the ENV config key was deprecated in Flask 2.2 and is no
# longer read from 2.3 on; DEBUG is the setting that still matters.
server.config['ENV'] = 'development'
server.config['DEBUG'] = True
Compress(server)
def static_vars(**kwargs):
    """Build a decorator that attaches each keyword argument to a function.

    ``@static_vars(count=0)`` performs ``func.count = 0`` once, when the
    decorator is applied. This gives the function persistent "static" state
    that survives between calls. The same function object is returned.
    """
    def _attach(func):
        for attr_name, initial_value in kwargs.items():
            setattr(func, attr_name, initial_value)
        return func

    return _attach
@static_vars(prev_frame_ids = [0] * n_chunks)
@server.route('/', methods = ['GET'])
def getCloud():
    """Return point data for every chunk whose predicted frame changed.

    A random training sample is used as the sensor input, and a frame id is
    predicted per chunk. A chunk's points are included only if its predicted
    id differs from the id recorded on the previous request. The recorded ids
    are stored on the function attribute ``prev_frame_ids``, which starts at
    frame 0 for every chunk.

    Returns:
        dict mapping chunk id -> base64 (ASCII) encoding of the raw bytes of
        the chunk's ``arr_0`` array. Flask serializes the dict to JSON, and
        raw ``bytes`` are not JSON-serializable, hence the base64 encoding.
    """
    # NOTE(review): only the first 100 samples are ever drawn (as originally
    # written); min() avoids an IndexError when there are fewer samples.
    sample_id = np.random.randint(min(100, len(train_data)))
    sensors = train_data[sample_id][0]
    pred_frame_ids = get_frame_ids(sensors)
    response = {}
    # NOTE(review): prev_frame_ids is one module-wide state shared by every
    # client, not per-client state.
    for chunk_id, (prev_frame_id, pred_frame_id) in enumerate(
            zip(getCloud.prev_frame_ids, pred_frame_ids)):
        if pred_frame_id == prev_frame_id:
            continue  # unchanged since the last response; presumably cached client-side
        chunk_path = dataset_name + '/chunk/%d-%d.npz' % (pred_frame_id, chunk_id)
        # np.load on an .npz returns an open archive; close it after reading.
        with np.load(chunk_path) as npz:
            chunk_points = npz['arr_0']
        response[chunk_id] = base64.b64encode(chunk_points.tobytes()).decode('ascii')
    getCloud.prev_frame_ids = pred_frame_ids
    print('%d chunks updated' % len(response))
    return response
# Start the Werkzeug development server only when run as a script.
# Importing this module (e.g. from a WSGI server) must not block here.
# debug=True enables the interactive debugger, which is for local use only.
if __name__ == '__main__':
    server.run(debug = True)