-
Notifications
You must be signed in to change notification settings - Fork 30
/
Copy pathnative_client.py
executable file
·351 lines (288 loc) · 12.8 KB
/
native_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
#!/usr/bin/env python3
'''A native client simulating the plugin to use for testing the server'''
import asyncio
import itertools
import struct
import json
import time
import sys
import csv
from pathlib import Path
from pprint import pprint
from tqdm import tqdm
class Timer:
    """Collect wall-clock durations of awaited coroutines for CSV export.

    Every successful `measure()` call appends one row to `measurements`:
    the elapsed time in seconds followed by the extra detail values that
    were passed along with the coroutine.
    """

    def __init__(self):
        # Rows of [elapsed_seconds, *details], appended as awaits complete.
        self.measurements = []

    async def measure(self, coro, *details):
        """Await `coro`, record how long that took, and return its result.

        If `coro` raises, the exception propagates and no row is recorded.
        """
        started_at = time.perf_counter()
        outcome = await coro
        elapsed = time.perf_counter() - started_at
        self.measurements.append([elapsed, *details])
        return outcome

    def dump(self, fh):
        """Write every recorded row as CSV to the file-like object `fh`."""
        # TODO stats? For now I just export to Excel or something
        csv.writer(fh).writerows(self.measurements)
class Client:
    """asyncio based native messaging client. Main interface is just calling
    `request()` with the right parameters and awaiting the future it returns.

    Messages in both directions are JSON documents prefixed with their byte
    length packed as a native unsigned int (`struct` format "@I").
    """

    def __init__(self, *args):
        """Remember the command line (`args`) used to spawn the child process."""
        self.serial = itertools.count(1)  # source of unique message ids
        self.futures = {}  # message id -> (future, update callback)
        self.args = args

    async def __aenter__(self):
        """Spawn the child process and start the background read loop."""
        self.proc = await asyncio.create_subprocess_exec(
            *self.args,
            stdin=asyncio.subprocess.PIPE,
            stdout=asyncio.subprocess.PIPE,
        )
        self.read_task = asyncio.create_task(self.reader())
        return self

    async def __aexit__(self, *args):
        """Close the child's stdin and wait for the process to exit."""
        self.proc.stdin.close()
        await self.proc.wait()

    def request(self, command, data, *, update=lambda data: None):
        """Send `command` with `data` to the child and return a future.

        The future resolves with the "data" of the matching response, or
        fails with an `Exception` carrying the response's "error". `update`
        is called with the "data" of every intermediate update message that
        carries this request's id. Must be called from within a running
        event loop.
        """
        message_id = next(self.serial)
        message = json.dumps({"command": command, "id": message_id, "data": data}).encode()
        future = asyncio.get_running_loop().create_future()
        # Register before writing so the reader can always find the entry.
        self.futures[message_id] = future, update
        self.proc.stdin.write(struct.pack("@I", len(message)))
        self.proc.stdin.write(message)
        return future

    async def reader(self):
        """Read framed messages from the child's stdout until EOF or cancel.

        When the loop ends for whatever reason, every request that is still
        waiting for a response is failed with `ConnectionError`, so callers
        awaiting those futures do not hang forever.
        """
        try:
            while True:
                try:
                    raw_length = await self.proc.stdout.readexactly(4)
                    length = struct.unpack("@I", raw_length)[0]
                    raw_message = await self.proc.stdout.readexactly(length)
                except asyncio.IncompleteReadError:
                    break  # Stop read loop if EOF is reached
                except asyncio.CancelledError:
                    break  # Also stop reading if we're cancelled
                self._dispatch(json.loads(raw_message))
        finally:
            self._fail_pending()

    def _dispatch(self, message):
        """Route one decoded message to the future/callback of its request."""
        # Not cool if there is no response message "id" here
        if "id" not in message:
            return
        entry = self.futures.get(message["id"])
        if entry is None:
            # Response to a request we do not (or no longer) know about:
            # ignore it instead of killing the read loop with a KeyError.
            return
        future, update = entry
        if "success" in message:
            del self.futures[message["id"]]
            if future.done():
                # The caller already cancelled this future; setting a result
                # now would raise InvalidStateError and kill the read loop.
                return
            if message["success"]:
                future.set_result(message["data"])
            else:
                future.set_exception(Exception(message["error"]))
        elif "update" in message:
            update(message["data"])

    def _fail_pending(self):
        """Fail all still-unanswered requests once no more data can arrive."""
        for future, _update in self.futures.values():
            if not future.done():
                future.set_exception(ConnectionError(
                    "Read loop stopped before a response to this request arrived"
                ))
        self.futures.clear()
class InvalidArgumentException(ValueError):
    """Raised by `TranslateLocally.translate()` for an unusable combination
    of arguments.
    """


class TranslateLocally(Client):
    """TranslateLocally wrapper around Client that translates
    our defined messages into functions with arguments.
    """

    async def list_models(self, *, include_remote=False):
        """Send a "ListModels" request and return the response data."""
        return await self.request("ListModels", {"includeRemote": bool(include_remote)})

    async def translate(self, text, src=None, trg=None, *, model=None, pivot=None, html=False):
        """Send a "Translate" request and return the translated text.

        Either pass `src` and `trg` together, or pass `model` (optionally
        with `pivot`).

        Raises:
            InvalidArgumentException: if `src` + `trg` are combined with
                `model`/`pivot`, or if neither `src` + `trg` nor `model`
                is given.
        """
        if src and trg:
            if model or pivot:
                raise InvalidArgumentException("Cannot combine src + trg and model + pivot arguments")
            spec = {"src": str(src), "trg": str(trg)}
        elif model:
            if pivot:
                spec = {"model": str(model), "pivot": str(pivot)}
            else:
                spec = {"model": str(model)}
        else:
            raise InvalidArgumentException("Missing src + trg or model argument")

        result = await self.request("Translate", {**spec, "text": str(text), "html": bool(html)})
        return result["target"]["text"]

    async def download_model(self, model_id, *, update=lambda data: None):
        """Send a "DownloadModel" request for `model_id` and return the
        response data; `update` receives the data of intermediate update
        messages.
        """
        return await self.request("DownloadModel", {"modelID": str(model_id)}, update=update)
def first(iterable, *default):
    """Return the first item produced by `iterable`.

    When `iterable` is empty, the optional `default` is returned instead;
    without a default, StopIteration is raised.
    """
    iterator = iter(iterable)
    # `default` is forwarded as rest arguments so that leaving it out really
    # means "no default" and next() raises StopIteration on an empty iterator.
    return next(iterator, *default)
def get_build():
    """Create a TranslateLocally client for a locally built binary.

    Looks for `./translateLocally` relative to the working directory first,
    then for `../build/translateLocally` relative to the directory of this
    file. The binary is started with the `-p` and `--debug` arguments once
    the returned client is entered.

    Raises:
        RuntimeError: if neither candidate path exists.
    """
    script_dir = Path(__file__).resolve().parent
    candidates = (
        Path("./translateLocally"),
        script_dir / Path("../build/translateLocally"),
    )
    binary = next((candidate for candidate in candidates if candidate.exists()), None)
    if binary is None:
        raise RuntimeError("Could not find translateLocally binary")
    return TranslateLocally(binary.resolve(), "-p", "--debug")
async def download_with_progress(tl, model, position):
    """Run `tl.download_model()` for `model` with a tqdm progress bar.

    `position` is handed to tqdm as the bar's position. Returns whatever
    `tl.download_model()` returns.
    """
    bar_options = dict(position=position, desc=model["modelName"], unit="b", unit_scale=True, leave=False)
    with tqdm(**bar_options) as bar:
        def on_progress(status):
            assert status["read"] <= status["size"]
            bar.total = status["size"]
            # tqdm wants increments, the update message reports the total
            # read so far, so advance by the difference.
            bar.update(status["read"] - bar.n)

        return await tl.download_model(model["id"], update=on_progress)
async def test():
    """Test TranslateLocally functionality.

    Lists models, downloads one model per required language pair, checks
    that those models are then reported as local, runs a batch of
    translations against expected output, and finally checks that an
    impossible language pair is rejected with the expected error.
    """
    async with get_build() as tl:
        models = await tl.list_models(include_remote=True)
        pprint(models)

        # Models necessary for tests, both direct & pivot
        necessary_models = {("en", "de"), ("en", "es"), ("es", "en")}

        # From all models available, pick one for every necessary language pair
        # (preferring tiny ones) so we can make sure these are downloaded.
        selected_models = {
            (src, trg): first(sorted(
                (
                    model
                    for model in models
                    if src in model["srcTags"] and trg == model["trgTag"]
                ),
                key=lambda model: 0 if model["type"] == "tiny" else 1
            ))
            for src, trg in necessary_models
        }
        pprint(selected_models)

        # Download them. Even if they're already model['local'] == True, to test
        # that in that case this is a no-op.
        await asyncio.gather(*(
            download_with_progress(tl, model, position)
            for position, model in enumerate(selected_models.values())
        ))
        print()  # tqdm messes a lot with the print position, this makes it less bad

        # Test whether the model list has been updated to reflect that the
        # downloaded models are now local.
        models = await tl.list_models(include_remote=True)
        assert all(
            model["local"]
            for selected_model in selected_models.values()
            for model in models
            if model["id"] == selected_model["id"]
        )

        # Perform some translations, switching between the models
        translations = await asyncio.gather(
            tl.translate("Hello world!", "en", "de"),
            tl.translate("Let's translate another sentence to German.", "en", "de"),
            tl.translate("Sticks and stones may break my bones but words WILL NEVER HURT ME!", "en", "es"),
            tl.translate("I <i>like</i> to drive my car. But I don't have one.", "en", "de", html=True),
            tl.translate("¿Por qué no funciona bien?", "es", "de"),
            tl.translate("This will be the last sentence of the day.", "en", "de"),
        )
        pprint(translations)
        assert translations == [
            "Hallo Welt!",
            "Übersetzen wir einen weiteren Satz mit Deutsch.",
            "Palos y piedras pueden romper mis huesos, pero las palabras NUNCA HURT ME.",
            "Ich <i>fahre gerne</i> mein Auto. Aber ich habe keine.",  #<i>fahre</i>???
            "Warum funktioniert es nicht gut?",
            "Dies wird der letzte Satz des Tages sein.",
        ]

        # Test bad input
        try:
            await tl.translate("This is impossible to translate", "en", "xx")
        except Exception as e:
            assert "Could not find the necessary translation models" in str(e)
        else:
            # Kept out of the try body: an AssertionError raised there would
            # be caught by the `except Exception` above and be reported as a
            # wrong error message instead of as an unexpected success.
            raise AssertionError("How are we able to translate to 'xx'???")

    print("Fin")
async def test_third_party():
    """Test whether TranslateLocally can switch between different types of
    models. This test assumes you have the OPUS repository in your list:
    https://object.pouta.csc.fi/OPUS-MT-models/app/models.json
    """
    async with get_build() as tl:
        models_to_try = [
            'en-de-tiny',
            'en-de-base',
            'eng-fin-tiny',  # model has broken model_info.json so won't work anyway :(
            'eng-ukr-tiny',
        ]

        models = await tl.list_models(include_remote=True)

        # Look up a model for each entry of models_to_try; entries without
        # an available model are simply left out.
        selected_models = {}
        for shortname in models_to_try:
            candidate = first(
                (entry for entry in models if entry["shortname"] == shortname),
                None,
            )
            if candidate:
                selected_models[shortname] = candidate

        await asyncio.gather(*(
            download_with_progress(tl, model, position)
            for position, model in enumerate(selected_models.values())
        ))

        # TODO: Temporary filter to figure out 'failed' downloads. eng-fin-tiny
        # has a broken JSON file so it will download correctly, but still not
        # be available or show up in this list. We should probably make the
        # download fail in that scenario.
        models = await tl.list_models(include_remote=False)
        # Iterate over a copy of the keys because entries are deleted in the loop.
        for shortname in list(selected_models.keys()):
            listed = any(entry["shortname"] == shortname for entry in models)
            if not listed:
                print(f"Skipping {shortname} because it didn't show up in model list after downloading", file=sys.stderr)
                del selected_models[shortname]

        pending = [
            tl.translate("This is a very simple test sentence", model=model["id"])
            for model in selected_models.values()
        ]
        translations = await asyncio.gather(*pending)

        pprint(list(zip(selected_models.keys(), translations)))
async def test_latency():
    """Time en->de translations of stdin lines in batches of growing size.

    Runs 100 epochs; in each epoch batches of 1, 5, 10, 20, 50 and 100 lines
    are translated concurrently. For every line the duration is recorded
    together with the epoch, the batch size and the number of
    space-separated words. All rows are written to stdout as CSV.
    """
    timer = Timer()
    batch_sizes = [1, 5, 10, 20, 50, 100]

    # Our line generator: just read Crime & Punishment from stdin :D
    lines = (line.strip() for line in sys.stdin)

    async with get_build() as tl:
        for epoch in range(100):
            print(f"Epoch {epoch}...", file=sys.stderr)
            for batch_size in batch_sizes:
                # Takes at most batch_size lines off the shared generator.
                batch = itertools.islice(lines, batch_size)
                await asyncio.gather(*(
                    timer.measure(
                        tl.translate(line, "en", "de"),
                        epoch,
                        batch_size,
                        len(line.split(' ')),
                    )
                    for line in batch
                ))

    timer.dump(sys.stdout)
async def test_concurrency():
    """Issue three ListModels requests at once and wait for all of them."""
    async with get_build() as tl:
        await asyncio.gather(*(
            tl.list_models(include_remote=include_remote)
            for include_remote in (True, False, True)
        ))
async def test_shutdown():
    """Queue ten translation requests, shut the client down right away, and
    only afterwards print the responses as they complete.
    """
    pending = []
    async with get_build() as tl:
        for n in range(10):
            print(f"Requesting translation {n}")
            payload = {
                "src": "en",
                "trg": "de",
                "text": f"This is simple sentence number {n}!",
                "html": False
            }
            pending.append(tl.request("Translate", payload))
        print("Shutting down")
    print("Shutdown complete")

    for finished in asyncio.as_completed(pending):
        print(await finished)
    print("Fin.")
async def test_concurrent_download():
    """Test parallel downloads: fetch up to three not-yet-local models at once."""
    async with get_build() as tl:
        models = await tl.list_models(include_remote=True)
        not_local = [model for model in models if not model["local"]]
        await asyncio.gather(*[
            tl.download_model(model["id"])
            for model in not_local[:3]
        ])
def main():
    """Run the test scenario named by the first command line argument.

    Prints a usage line listing the available scenario names to stderr when
    the argument is missing or unknown.
    """
    tests = {
        "test": test,
        "third-party": test_third_party,
        "latency": test_latency,
        "concurrency": test_concurrency,
        "shutdown": test_shutdown,
        "concurrent-downloads": test_concurrent_download
    }

    if len(sys.argv) == 1 or sys.argv[1] not in tests:
        print(f"Usage: {sys.argv[0]} {' | '.join(tests.keys())}", file=sys.stderr)
    else:
        asyncio.run(tests[sys.argv[1]]())


# Only run when executed as a script, so the client classes can be imported
# elsewhere without triggering a test run.
if __name__ == "__main__":
    main()