simple_html_to_text_classifier.py

import pickle
import os
import json
from uuid import uuid4

import numpy as np
import requests
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import make_pipeline

from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.utils import DATA_UNDEFINED_NAME, get_env

HOSTNAME = get_env("HOSTNAME", "http://localhost:8080")
API_KEY = get_env("API_KEY")

print("=> LABEL STUDIO HOSTNAME = ", HOSTNAME)
if not API_KEY:
    print("=> WARNING! API_KEY is not set")

class SimpleHTMLToTextClassifier(LabelStudioMLBase):
    def __init__(self, **kwargs):
        # don't forget to initialize the base class...
        super(SimpleHTMLToTextClassifier, self).__init__(**kwargs)

        # then collect all keys from the label config that are used to extract
        # data from tasks and to form predictions.
        # The parsed label config must contain exactly one output of <Choices> type
        assert len(self.parsed_label_config) == 1
        self.from_name, self.info = list(self.parsed_label_config.items())[0]
        assert self.info["type"] == "Choices"

        # the model has only one textual input
        assert len(self.info["to_name"]) == 1
        assert len(self.info["inputs"]) == 1
        assert self.info["inputs"][0]["type"] == "Text"
        self.to_name = self.info["to_name"][0]
        self.value = self.info["inputs"][0]["value"]

        if not self.train_output:
            # if there are no previous training runs, cold-start a simple
            # TF-IDF text classifier
            self.reset_model()
            # this is an array of <Choice> labels
            self.labels = self.info["labels"]
            # make a dummy initialization so predict() works before the first
            # real training run
            self.model.fit(X=self.labels, y=list(range(len(self.labels))))
            print(
                "Initialized with from_name={from_name}, to_name={to_name}, labels={labels}".format(
                    from_name=self.from_name,
                    to_name=self.to_name,
                    labels=str(self.labels),
                )
            )
        else:
            # otherwise load the model from the latest training results
            self.model_file = self.train_output["model_file"]
            with open(self.model_file, mode="rb") as f:
                self.model = pickle.load(f)
            # and use the labels from the training output
            self.labels = self.train_output["labels"]
            print(
                "Loaded from train output with from_name={from_name}, to_name={to_name}, labels={labels}".format(
                    from_name=self.from_name,
                    to_name=self.to_name,
                    labels=str(self.labels),
                )
            )

    def reset_model(self):
        self.model = make_pipeline(
            TfidfVectorizer(ngram_range=(1, 3), token_pattern=r"(?u)\b\w\w+\b|\w"),
            LogisticRegression(C=10, verbose=True),
        )
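
    # Note on reset_model: the token_pattern above extends scikit-learn's
    # default (r"(?u)\b\w\w+\b", which drops single-character tokens) with
    # "|\w" so that one-character words are also kept as features.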

    def predict(self, tasks, **kwargs):
        # collect input texts
        input_texts = []
        for task in tasks:
            input_text = task["data"].get(self.value) or task["data"].get(
                DATA_UNDEFINED_NAME
            )
            input_texts.append(input_text)

        # get model predictions
        probabilities = self.model.predict_proba(input_texts)
        print("=== probabilities >", probabilities)
        predicted_label_indices = np.argmax(probabilities, axis=1)
        predicted_scores = probabilities[
            np.arange(len(predicted_label_indices)), predicted_label_indices
        ]
        predictions = []
        for idx, score in zip(predicted_label_indices, predicted_scores):
            predicted_label = self.labels[idx]
            # prediction result for a single task
            result = [
                {
                    "from_name": self.from_name,
                    "to_name": self.to_name,
                    "type": "choices",
                    "value": {"choices": [predicted_label]},
                }
            ]
            # one prediction with its score per task; cast the numpy float to
            # a plain float so the response stays JSON-serializable
            predictions.append({"result": result, "score": float(score)})
        return predictions
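
    # For reference, each element returned by predict() follows the Label
    # Studio predictions format; roughly (values here are illustrative):
    #
    #   {
    #       "result": [{
    #           "from_name": "<choices tag name>",
    #           "to_name": "<text tag name>",
    #           "type": "choices",
    #           "value": {"choices": ["<predicted label>"]},
    #       }],
    #       "score": 0.87,
    #   }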

    def _get_annotated_dataset(self, project_id):
        """Just for demo purposes: retrieve annotated data from the Label Studio API"""
        download_url = f'{HOSTNAME.rstrip("/")}/api/projects/{project_id}/export'
        response = requests.get(
            download_url, headers={"Authorization": f"Token {API_KEY}"}
        )
        if response.status_code != 200:
            raise Exception(
                f"Can't load task data using {download_url}, "
                f"response status_code = {response.status_code}"
            )
        return json.loads(response.content)
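
    # The /export endpoint is expected to return a JSON list of tasks, each
    # carrying an "annotations" list, which is the shape fit() consumes below.
    # Exact payload details may vary with the Label Studio version.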

    def fit(self, annotations, workdir=None, **kwargs):
        # check whether training was triggered by a webhook
        if kwargs.get("data"):
            project_id = kwargs["data"]["project"]["id"]
            tasks = self._get_annotated_dataset(project_id)
        # ML training without a webhook
        else:
            tasks = annotations

        input_texts = []
        output_labels, output_labels_idx = [], []
        label2idx = {l: i for i, l in enumerate(self.labels)}
        for task in tasks:
            if not task.get("annotations"):
                continue
            annotation = task["annotations"][0]
            # skip annotations that were skipped or cancelled
            if annotation.get("skipped") or annotation.get("was_cancelled"):
                continue
            # get the input text from task data
            input_text = task["data"].get(self.value) or task["data"].get(
                DATA_UNDEFINED_NAME
            )
            input_texts.append(input_text)
            # get the annotated label
            output_label = annotation["result"][0]["value"]["choices"][0]
            output_labels.append(output_label)
            output_label_idx = label2idx[output_label]
            output_labels_idx.append(output_label_idx)

        # if the set of annotated labels differs from the configured one,
        # rebuild the label index
        new_labels = set(output_labels)
        if len(new_labels) != len(self.labels):
            self.labels = list(sorted(new_labels))
            print("Label set has been changed: " + str(self.labels))
            label2idx = {l: i for i, l in enumerate(self.labels)}
            output_labels_idx = [label2idx[label] for label in output_labels]

        # train the model
        print(f"Start training on {len(input_texts)} samples")
        self.reset_model()
        self.model.fit(input_texts, output_labels_idx)

        # save output resources
        workdir = workdir or os.getenv("MODEL_DIR")
        model_name = str(uuid4())[:8]
        if workdir:
            model_file = os.path.join(workdir, f"{model_name}.pkl")
        else:
            model_file = f"{model_name}.pkl"
        print(f"Save model to {model_file}")
        with open(model_file, mode="wb") as fout:
            pickle.dump(self.model, fout)

        train_output = {"labels": self.labels, "model_file": model_file}
        return train_output
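

if __name__ == "__main__":
    # Quick offline sanity check of the underlying scikit-learn pipeline only.
    # This is a minimal sketch for illustration: it does not exercise the
    # Label Studio backend machinery, and the texts and labels below are
    # made-up examples.
    demo_texts = [
        "<p>Great product, works as advertised</p>",
        "<p>Broke after two days, very disappointed</p>",
        "<p>Absolutely love it, five stars</p>",
        "<p>Terrible quality, do not buy</p>",
    ]
    demo_labels = [1, 0, 1, 0]  # 1 = positive, 0 = negative (hypothetical)

    # same pipeline shape as reset_model(), without the verbose flag
    demo_model = make_pipeline(
        TfidfVectorizer(ngram_range=(1, 3), token_pattern=r"(?u)\b\w\w+\b|\w"),
        LogisticRegression(C=10),
    )
    demo_model.fit(demo_texts, demo_labels)
    print(
        "demo probabilities:",
        demo_model.predict_proba(["<p>I really like this one</p>"]),
    )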