-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathbackend.py
93 lines (72 loc) · 2.7 KB
/
backend.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import json
import os
import random
import numpy as np
import time
import redis
import msgpack
def read_data(filename='dict/parsed_wiki_fr2sp.json'):
    """Load the word dictionary, using Redis as a cache.

    The parsed dictionary is cached in Redis under the key ``filename``.
    On a cache miss the JSON file is read, every entry lacking a
    ``'frequency'`` field gets a frequency of 1, and the result is written
    back to Redis before being returned.

    Args:
        filename: Path to the JSON dictionary file; also used as the Redis
            cache key.

    Returns:
        dict mapping each word to its entry dict. Entries loaded from the
        file always carry a ``'frequency'`` key. A cache hit returns
        whatever was cached under ``filename``, which this function stored
        after filling in frequencies (assuming nothing else writes that key).
    """
    # NOTE(review): host 'redis' is hard-coded; presumably the name of a
    # Redis service in the deployment -- confirm before reusing elsewhere.
    db = redis.Redis(host='redis')
    pack = db.get(filename)
    if pack is not None:
        # raw=False decodes msgpack strings to str, so a cache hit yields the
        # same str keys as the json.load path below (msgpack < 1.0 would
        # otherwise return bytes).
        return msgpack.unpackb(pack, raw=False)
    # Decode explicitly as UTF-8 rather than relying on the platform locale;
    # the dictionary presumably contains non-ASCII words (see default name).
    with open(filename, 'r', encoding='utf-8') as f:
        data = json.load(f)
    for word_dict in data.values():
        # Words without a recorded frequency get a default weight of 1.
        word_dict.setdefault('frequency', 1)
    db.set(filename, msgpack.packb(data, use_bin_type=True))
    return data
def generate_words(data, user_past, p_new_word=0.7, num_words=100):
    """Pick up to ``num_words`` words for a practice session.

    Words the user has never answered are sampled without replacement,
    weighted by their ``'frequency'``. Each sampled slot is then, with
    probability ``1 - p_new_word``, replaced by a word whose most recent
    answer was incorrect (each such word is used at most once). If there
    are not enough unseen words with a nonzero weight, the remaining slots
    are filled with leftover incorrectly-answered words, so the result is
    shorter than ``num_words`` only when both pools are exhausted.

    Args:
        data: Mapping word -> entry dict with a numeric ``'frequency'``
            (as returned by ``read_data``).
        user_past: Mapping word -> dict whose ``'correct'`` list holds the
            user's answer results, most recent last (as ``User.log_entry``
            records them).
        p_new_word: Probability in [0, 1] of keeping an unseen word in a
            slot instead of swapping in an incorrectly-answered one.
        num_words: Maximum number of words to return.

    Returns:
        list of str words, without duplicates.

    Raises:
        AssertionError: if ``p_new_word`` is outside [0, 1] (an ``assert``,
            so the check is skipped under ``python -O``).
    """
    assert 0 <= p_new_word <= 1, (
        f'p_new_word ({p_new_word}) should be in [0, 1]')

    # Partition the vocabulary: words never answered vs. words whose last
    # answer was wrong.
    unseen = set(data.keys())
    seen_incorrect = set()
    for word, word_data in user_past.items():
        unseen.discard(word)
        if not word_data['correct'][-1]:
            seen_incorrect.add(word)

    # Materialize the set once so weights and candidates share one order.
    candidates = list(unseen)
    weights = [data[word]['frequency'] for word in candidates]
    total = sum(weights)
    if total > 0:
        proba = [w / total for w in weights]
        # Sampling without replacement can only draw nonzero-weight items.
        available = sum(1 for w in weights if w > 0)
    else:
        # No usable weights (or no candidates): fall back to uniform.
        proba = None
        available = len(candidates)

    # Never ask numpy for more samples than exist; it raises otherwise.
    size = min(num_words, available)
    samples = []
    if size > 0:
        samples = [str(word) for word in np.random.choice(
            candidates, size=size, replace=False, p=proba)]

    # Randomly swap some unseen words for words the user got wrong.
    for i in range(len(samples)):
        if not seen_incorrect:
            break
        if random.uniform(0, 1) >= p_new_word:
            incorrect_word = random.choice(list(seen_incorrect))
            seen_incorrect.discard(incorrect_word)
            samples[i] = incorrect_word

    # If unseen words ran out, top up with the remaining incorrect words.
    shortfall = num_words - len(samples)
    if shortfall > 0 and seen_incorrect:
        samples.extend(random.sample(
            sorted(seen_incorrect), min(shortfall, len(seen_incorrect))))
    return samples
class User(object):
    """Per-word answer history for one user, persisted to JSON and Redis.

    ``self.past`` maps each word to ``{'timestamp': [...], 'correct': [...]}``,
    two parallel lists with one entry per logged answer.
    """

    def __init__(self, user_filename='default_user_data.json'):
        """Connect to Redis and load any saved history for ``user_filename``."""
        self._user_filename = user_filename
        # NOTE(review): host 'redis' is hard-coded; presumably a service
        # name in the deployment -- confirm.
        self._db = redis.Redis(host='redis')
        self.load_past()

    @property
    def _cache_key(self):
        """Redis key for this user's history, derived from the filename.

        Keying by filename (instead of a fixed key) keeps different users
        from reading or overwriting each other's cached history.
        """
        return f'user:{self._user_filename}'

    def log_entry(self, word, result):
        """Append one answer for ``word``: the current time and ``result``."""
        now = time.time()
        entry = self.past.setdefault(word, {'timestamp': [], 'correct': []})
        entry['timestamp'].append(now)
        entry['correct'].append(result)

    def load_past(self):
        """Populate ``self.past`` from Redis, else the JSON file, else empty."""
        pack = self._db.get(self._cache_key)
        if pack is not None:
            # raw=False yields str keys (msgpack < 1.0 defaults to bytes,
            # which would break log_entry lookups and json.dump).
            self.past = msgpack.unpackb(pack, raw=False)
        elif os.path.exists(self._user_filename):
            with open(self._user_filename, 'r', encoding='utf-8') as f:
                self.past = json.load(f)
        else:
            self.past = {}

    def save_past(self):
        """Write ``self.past`` to the JSON file and refresh the Redis cache.

        The JSON is written to a temporary file and then moved over the
        target, so a failure mid-dump leaves the previous file intact
        instead of truncated.
        """
        tmp_filename = self._user_filename + '.tmp'
        with open(tmp_filename, 'w', encoding='utf-8') as f:
            json.dump(self.past, f)
        os.replace(tmp_filename, self._user_filename)
        self._db.set(self._cache_key,
                     msgpack.packb(self.past, use_bin_type=True))