# closest_celeb.py -- forked from buzem/inzpeech (36 lines, 31 loc, 1.26 KB)
import os
import fnmatch
import pickle5 as pickle
import numpy as np
from scipy.spatial.distance import cdist
def load_embeds(embeding_main_path):
    """Return the paths of all ``*.pkl`` files below ``embeding_main_path``.

    The directory tree is walked recursively with ``os.walk`` and every file
    whose name matches ``*.pkl`` is joined with its containing directory.

    Args:
        embeding_main_path: Root directory to search.

    Returns:
        A sorted list of path strings; empty when nothing matches.
    """
    pkl_matches = [
        os.path.join(root, filename)
        for root, _dirnames, filenames in os.walk(embeding_main_path)
        for filename in fnmatch.filter(filenames, '*.pkl')
    ]
    # os.walk yields entries in arbitrary filesystem order; sorting makes the
    # result (and anything indexed by it) reproducible across runs/machines.
    pkl_matches.sort()
    return pkl_matches
class NearestNeighboor(object):
    """Brute-force nearest-neighbour lookup over pickled embeddings.

    All state lives on the class itself rather than on instances:

    * ``labels`` -- list holding one label per loaded pickle file.
    * ``embeds`` -- 2-D float array, one embedding per row, row ``i``
      belonging to ``labels[i]``.

    ``init_neighbor`` must be called before ``closest_labels``.
    """

    @classmethod
    def init_neighbor(cls, data_path, embed_dim=256):
        """Load every pickled embedding found under ``data_path``.

        Each pickle is indexed as ``obj[0]`` (the embedding, stored as one
        row of ``embeds``) and ``obj[1]`` (its label).

        Args:
            data_path: Directory searched (via ``load_embeds``) for pickles.
            embed_dim: Width of one embedding row. Defaults to 256.
        """
        pkl_paths = load_embeds(data_path)
        total = len(pkl_paths)

        # Build into locals and publish at the end, so a failing file does
        # not leave the class with half-filled labels/embeds.
        labels = []
        embeds = np.zeros((total, embed_dim))
        for i, p in enumerate(pkl_paths):
            print("Progress: ", i, " / ", total, end='\r')
            # NOTE(security): unpickling can execute arbitrary code; only
            # point this at embedding files from a trusted source.
            with open(p, 'rb') as pfile:
                loaded_pkl = pickle.load(pfile)
            embeds[i] = loaded_pkl[0]
            labels.append(loaded_pkl[1])

        cls.labels = labels
        cls.embeds = embeds

    @classmethod
    def closest_labels(cls, test_sample, k):
        """Return the unique labels of the ``k`` embeddings nearest a sample.

        Args:
            test_sample: One query embedding, either 1-D or as a single-row
                2-D array.
            k: Number of nearest neighbours to consider. Values larger than
                the number of stored embeddings are clamped; ``k <= 0``
                yields an empty result.

        Returns:
            Sorted array of the distinct labels among the ``k`` nearest
            stored embeddings (so it may hold fewer than ``k`` entries).

        Raises:
            ValueError: If ``test_sample`` holds more than one sample.
        """
        query = np.atleast_2d(np.asarray(test_sample))
        if query.shape[0] != 1:
            raise ValueError(
                "closest_labels expects exactly one sample, got %d"
                % query.shape[0]
            )

        # Squared Euclidean distance ranks neighbours the same as the plain
        # Euclidean distance and skips the square root.
        dist = cdist(cls.embeds, query, 'sqeuclidean').reshape(-1)
        labels = np.array(cls.labels)

        if k <= 0:
            return np.unique(labels[:0])
        if k >= dist.size:
            # np.argpartition raises when kth is out of bounds; asking for
            # at least as many neighbours as exist simply means "all".
            indx = np.arange(dist.size)
        else:
            # Unordered selection of the k smallest distances.
            indx = np.argpartition(dist, k)[:k]
        return np.unique(labels[indx])