forked from Artisan-Lab/SMTimer
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathKNN_Predictor.py
62 lines (56 loc) · 2.15 KB
/
KNN_Predictor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
import json
from .dgl_treelstm.KNN import KNN
import numpy as np
from .preprocessing import Vector_Dataset
class Predictor:
    """KNN classifier that labels a script 1/0 against a time threshold.

    All configuration, the fitted model and the feature extractor are class-level
    state. Every instance reads and rebinds the same ``Predictor.*`` attributes,
    so constructing a new instance reconfigures and refits the shared model.
    """

    # Shared KNN model; (re)built by init_static().
    model = None
    # Feature extractor shared by all instances (created once, at class definition).
    dataset = Vector_Dataset(feature_number_limit=2)
    # Name of the script under evaluation; passed to model.remove_test().
    filename = None
    # Training-data file name inside KNN_training_data/ (".json" may be omitted).
    load_file = None
    # Training samples whose "adjust" value exceeds this are labelled 1, else 0.
    timeout_threshold = 200
    # Width of one feature row; predict() reshapes extracted features to this.
    feature_dim = 300
    # Training file used when load_file cannot be read or parsed.
    default_training_file = "gnucore.json"

    def __init__(self, filename, timeout_threshold=200, load_file="gnucore"):
        """Configure the shared predictor state and (re)train the model.

        :param filename: name of the script being predicted; handed to
            ``model.remove_test`` (presumably to exclude it from the training
            set -- TODO confirm). An empty string disables the removal step
            in predict().
        :param timeout_threshold: training samples whose ``"adjust"`` value is
            greater than this are labelled 1, all others 0.
        :param load_file: training file inside ``KNN_training_data/``; a
            missing ``.json`` suffix is tolerated. If it cannot be read or
            parsed, ``default_training_file`` is used instead.
        """
        Predictor.load_file = load_file
        Predictor.filename = filename
        Predictor.timeout_threshold = timeout_threshold
        # Refit on every construction so the new threshold / file take effect.
        self.init_static()
        # Whether this instance has already asked the model to drop `filename`.
        self.remove_name = False
        # Features from the last successful predict(); reused by increment_KNN_data().
        self.x = np.zeros((1, Predictor.feature_dim))

    @staticmethod
    def _read_training_data():
        """Load and return the training JSON selected by ``Predictor.load_file``.

        Candidates are tried in order:
        1. the name as given;
        2. the name with ``.json`` appended, when it lacks that suffix;
        3. ``default_training_file``.

        Unreadable or unparsable candidates are skipped. Errors from the
        default file propagate to the caller.
        """
        data_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)),
                                "KNN_training_data")
        candidates = []
        # load_file is None until an instance has been constructed.
        if Predictor.load_file:
            name = str(Predictor.load_file)
            candidates.append(name)
            if not name.endswith(".json"):
                candidates.append(name + ".json")
        for candidate in candidates:
            try:
                with open(os.path.join(data_dir, candidate), "r", encoding="utf-8") as f:
                    return json.load(f)
            except (IOError, ValueError):
                # Missing file or malformed JSON: try the next candidate.
                continue
        with open(os.path.join(data_dir, Predictor.default_training_file), "r",
                  encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def init_static():
        """(Re)build and fit the shared KNN model from the configured training file.

        The training JSON must provide ``"x"`` (feature rows) and ``"adjust"``
        (per-sample values compared against ``timeout_threshold``). An optional
        ``"filename"`` list is attached to the model so that
        ``Predictor.filename`` can be passed to ``model.remove_test``.
        ``Predictor.model`` is only replaced once fitting has succeeded.
        """
        train_dataset = Predictor._read_training_data()
        threshold = Predictor.timeout_threshold
        y_train = np.array([1 if value > threshold else 0
                            for value in train_dataset["adjust"]])
        x_train = np.array(train_dataset["x"])
        model = KNN(k=3)
        model.fit(x_train, y_train)
        Predictor.model = model
        try:
            model.filename = np.array(train_dataset["filename"])
            model.remove_test(Predictor.filename)
        except KeyError:
            # No "filename" entry (or remove_test raised KeyError, presumably
            # because the name is absent): keep the full training set.
            pass

    def predict(self, script):
        """Classify ``script`` with the shared KNN model.

        :param script: input handed to ``Vector_Dataset.generate_feature_dataset``.
        :return: the model's prediction for the first feature row of the last
            generated dataset entry, or 0 when feature extraction raises
            KeyError/IndexError (the error is printed).
        """
        if Predictor.model is None:
            Predictor.init_static()
        # Ask the model to drop the current script once per instance.
        if Predictor.filename != "" and not self.remove_name:
            try:
                Predictor.model.remove_test(Predictor.filename)
            except KeyError:
                pass
            self.remove_name = True
        model = Predictor.model
        try:
            dataset = Predictor.dataset.generate_feature_dataset([script], time_selection="z3")
        except (KeyError, IndexError) as e:
            print(e)
            return 0
        self.x = np.array(dataset[-1].feature).reshape(-1, Predictor.feature_dim)
        return model.predict(self.x)[0]

    def increment_KNN_data(self, truth):
        """Add ``self.x`` with label ``truth`` to the shared model incrementally.

        :param truth: ground-truth label for the features in ``self.x``.

        NOTE(review): ``self.x`` is only refreshed by a *successful* predict().
        After a failed feature extraction, or before any predict() call (zeros),
        this feeds stale or placeholder features -- confirm callers guard
        against that.
        """
        Predictor.model.incremental(self.x, truth)