diff --git a/clustinator/clustering.py b/clustinator/clustering.py index ed6c498..36be37b 100644 --- a/clustinator/clustering.py +++ b/clustinator/clustering.py @@ -4,6 +4,7 @@ from sklearn.cluster import DBSCAN import numpy as np +import math # Data imports PATH = "../data/raw/" @@ -45,7 +46,17 @@ def cluster_means(self): for label in np.unique(labels): chains_with_label = self.X[labels == label] mean_chain = sum(chains_with_label) / chains_with_label.shape[0] - # TODO: normalize rows to 1 (if a session misses an endpoint, that row will be 0 --> mean can be less than 1) - cluster_mean_dict[str(label)] = mean_chain.toarray()[0].tolist() + mean_chain = mean_chain.toarray()[0].tolist() + row_length = int(round(math.sqrt(len(mean_chain)))) + + for i in range(len(mean_chain) // row_length): + row_sum = sum(mean_chain[i * row_length:(i+1) * row_length]) + + if row_sum > 0 and abs(row_sum - 1.0) > 0.001: + # if a session misses an endpoint, that row will be 0 --> mean can be less than 1 + for j in range(i * row_length, (i+1) * row_length): + mean_chain[j] = mean_chain[j] / row_sum + + cluster_mean_dict[str(label)] = mean_chain return cluster_mean_dict