-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathAbideSubcortical_spd_manifold.py
130 lines (101 loc) · 4.09 KB
/
AbideSubcortical_spd_manifold.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
Operations on the manifold of SPD matrices and mapping to a flat space.
"""
import numpy as np
from scipy import linalg
def frobenius(mat):
    """Return the Frobenius norm of ``mat`` divided by ``mat.size``.

    Note: despite the name this is not the plain Frobenius norm -- the
    result is scaled by 1 / (number of entries). ``log_mean`` uses this
    value as its convergence measure, so its ``eps`` is on this scale.
    """
    squared_entries = mat ** 2
    return np.sqrt(squared_entries.sum()) / mat.size
def sqrtm(mat):
    """Return the square root of a symmetric positive definite matrix.

    With the eigendecomposition mat = V diag(w) V^T, the result is
    V diag(sqrt(w)) V^T.
    """
    eigvals, eigvecs = linalg.eigh(mat)
    # Broadcasting scales column j of V by sqrt(w_j), i.e. V diag(sqrt(w)).
    scaled = eigvecs * np.sqrt(eigvals)
    return scaled.dot(eigvecs.T)
def inv_sqrtm(mat):
    """Return the inverse square root mat^{-1/2} of an SPD matrix.

    With the eigendecomposition mat = V diag(w) V^T, the result is
    V diag(1 / sqrt(w)) V^T.
    """
    eigvals, eigvecs = linalg.eigh(mat)
    # Dividing column j of V by sqrt(w_j) gives V diag(1 / sqrt(w)).
    scaled = eigvecs / np.sqrt(eigvals)
    return scaled.dot(eigvecs.T)
def expm(mat):
    """Return the matrix exponential of a symmetric matrix.

    Only symmetry is required: eigenvalues may be zero or negative (e.g. a
    tangent vector). The result V diag(exp(w)) V^T is symmetric positive
    definite.
    """
    eigvals, eigvecs = linalg.eigh(mat)
    scaled = eigvecs * np.exp(eigvals)
    return scaled.dot(eigvecs.T)
def logm(mat):
    """Return the matrix logarithm of a symmetric positive definite matrix.

    Computed as V diag(log(w)) V^T from the eigendecomposition; eigenvalues
    must be positive (``np.log`` yields -inf/nan otherwise).
    """
    eigvals, eigvecs = linalg.eigh(mat)
    scaled = eigvecs * np.log(eigvals)
    return scaled.dot(eigvecs.T)
def log_map(x, displacement, mean=False):
    """Riemannian log map of SPD matrices ``x`` at base point ``displacement``.

    For one matrix X and base point D this returns the tangent vector
        Log_D(X) = D^{1/2} logm(D^{-1/2} X D^{-1/2}) D^{1/2}
    (algorithm 2 of Fletcher and Joshi, Sig Proc 87 (2007) 250).

    Parameters
    ----------
    x : array-like
        Either a single 2-D matrix, or a sequence / stack of 2-D matrices.
    displacement : 2-D array
        Symmetric positive definite base point D.
    mean : bool, optional
        Only used when ``x`` holds several matrices: if True, return the
        mean of their tangent vectors instead of the individual vectors.

    Returns
    -------
    A 2-D array when ``x`` is a single matrix or ``mean`` is True,
    otherwise a list of 2-D arrays (one per input matrix).
    """
    x = np.asanyarray(x)
    vals, vecs = linalg.eigh(displacement)
    sqrt_displacement = np.dot(vecs * np.sqrt(vals), vecs.T)
    whitening = np.dot(vecs / np.sqrt(vals), vecs.T)

    def _whitened_log(mat):
        # logm(D^{-1/2} mat D^{-1/2}) via an eigendecomposition; the argument
        # is SPD whenever mat and D are (congruence preserves definiteness).
        w_vals, w_vecs = linalg.eigh(np.dot(np.dot(whitening, mat), whitening))
        return np.dot(w_vecs * np.log(w_vals), w_vecs.T)

    def _unwhiten(tangent):
        # Map back from the whitened frame: D^{1/2} tangent D^{1/2}.
        return np.dot(np.dot(sqrt_displacement, tangent), sqrt_displacement)

    if x.ndim == 2:
        # It is the logarithm that must be mapped back, not the raw input.
        return _unwhiten(_whitened_log(x))
    log_x = [_whitened_log(m) for m in x]
    if mean:
        return _unwhiten(np.mean(log_x, axis=0))
    return [_unwhiten(tangent) for tangent in log_x]
def exp_map(x, displacement):
    """Riemannian exp map of tangent vector ``x`` at base point ``displacement``.

    Computes Exp_D(x) = D^{1/2} expm(D^{-1/2} x D^{-1/2}) D^{1/2}
    (algorithm 1 of Fletcher and Joshi, Sig Proc 87 (2007) 250).
    ``displacement`` must be SPD; ``x`` is expected to be a symmetric
    2-D matrix, since ``expm`` relies on ``eigh``.
    """
    eigvals, eigvecs = linalg.eigh(displacement)
    root = np.dot(eigvecs * np.sqrt(eigvals), eigvecs.T)
    inv_root = np.dot(eigvecs / np.sqrt(eigvals), eigvecs.T)
    whitened = np.dot(np.dot(inv_root, x), inv_root)
    return np.dot(np.dot(root, expm(whitened)), root)
def log_mean(population_covs, eps=1e-6, max_iter=1000):
    """Find the Riemannian mean of the covariances.

    Gradient descent on the SPD manifold with a backtracking step size
    (algorithm 3 of Fletcher and Joshi, Sig Proc 87 (2007) 250), starting
    from the arithmetic mean.

    Parameters
    ----------
    population_covs : array-like
        Stack of symmetric positive definite matrices.
    eps : float, optional
        Stop once ``frobenius`` of the mean tangent vector is <= ``eps``.
    max_iter : int, optional
        Upper bound on the number of descent iterations, so the loop always
        terminates. A RuntimeWarning is issued if it is reached before
        convergence.

    Returns
    -------
    mean : 2-D array
        The estimated Riemannian mean.
    """
    import warnings

    step = 1.
    mean = np.mean(population_covs, axis=0)
    # Mean tangent vector (negative gradient direction) at the current mean.
    direction = log_map(population_covs, mean, mean=True)
    norm = frobenius(direction)
    for _ in range(max_iter):
        if norm <= eps:
            return mean
        # Trial step along the direction computed *at the current mean*.
        candidate = exp_map(step * direction, mean)
        new_direction = log_map(population_covs, candidate, mean=True)
        new_norm = frobenius(new_direction)
        if new_norm < norm:
            # Accept: direction is now the gradient at the new point.
            mean, direction, norm = candidate, new_direction, new_norm
        else:
            # Reject: stay put and retry with a smaller step. Moving anyway
            # would leave ``direction`` describing a different base point.
            step *= .5
    if norm > eps:
        warnings.warn(
            "log_mean did not converge within %d iterations "
            "(residual %g > eps=%g)" % (max_iter, norm, eps),
            RuntimeWarning)
    return mean
def projection(subject_cov, population_covs, whitening=None):
    """Whiten subject covariance(s): each matrix C is mapped to W C W.

    Parameters
    ----------
    subject_cov : ndarray
        A single 2-D matrix, or a 3-D stack of matrices.
    population_covs : ndarray
        Stack of matrices; only used to build W when ``whitening`` is None.
    whitening : 2-D array, optional
        W. Defaults to the inverse square root of the arithmetic mean of
        ``population_covs``.

    Returns
    -------
    ndarray
        Same layout as ``subject_cov``.
    """
    if whitening is None:
        whitening = inv_sqrtm(population_covs.mean(axis=0))

    def _whiten(cov):
        return np.dot(np.dot(whitening, cov), whitening)

    if len(subject_cov.shape) != 3:
        return _whiten(subject_cov)
    return np.array([_whiten(cov) for cov in subject_cov])
def riemannian_projection(subject_cov, population_covs, whitening=None):
    """Map subject covariance(s) to the tangent space at the Riemannian mean.

    The base point is ``log_mean(population_covs)``; each subject matrix is
    then sent through ``log_map`` at that point.

    Parameters
    ----------
    subject_cov : ndarray
        A single 2-D matrix, or a 3-D stack of matrices.
    population_covs : array-like
        Matrices whose Riemannian mean serves as the base point.
    whitening : ignored
        Not used here; presumably kept so the signature matches
        ``projection`` -- confirm with callers.

    Returns
    -------
    For a 3-D stack, an ndarray built from the per-matrix ``log_map``
    results; otherwise the ``log_map`` result for the single matrix.
    """
    base_point = log_mean(population_covs)
    if len(subject_cov.shape) != 3:
        return log_map(subject_cov, base_point)
    tangents = [log_map(cov, base_point) for cov in subject_cov]
    return np.array(tangents)
def sym_to_vec(sym):
    """Flatten symmetric matrices into vectors of their lower triangle.

    Off-diagonal entries are multiplied by sqrt(2) and diagonal entries are
    kept as is, so the Euclidean norm of the vector equals the Frobenius
    norm of the matrix. Entries are read row by row from the lower triangle,
    diagonal included.

    Parameters
    ----------
    sym : array-like, shape (..., n, n)
        Symmetric matrix, or a stack of them (leading axes are preserved).

    Returns
    -------
    vec : ndarray, shape (..., n * (n + 1) / 2)
        Floating input keeps its dtype; integer input is promoted to a
        floating dtype.
    """
    sym = np.asarray(sym)
    # A Python float (not np.float64) so float32 input stays float32.
    sqrt2 = float(np.sqrt(2))
    # Scale only the off-diagonal entries. This keeps the diagonal exact,
    # works on stacks, and does not mutate or require a float input array.
    on_diag = np.eye(sym.shape[-1], dtype=bool)
    scaled = np.where(on_diag, sym, sym * sqrt2)
    # Builtin ``bool``: the ``np.bool`` alias is absent in NumPy 1.24-1.26.
    mask = np.tril(np.ones(sym.shape[-2:], dtype=bool))
    return scaled[..., mask]
def vec_to_sym(vec, shape):
    """Rebuild symmetric matrices from their lower-triangle vectors.

    Inverse of ``sym_to_vec``: the vector fills the lower triangle (row by
    row, diagonal included), off-diagonal entries are divided by sqrt(2),
    and the upper triangle is filled by symmetry.

    Parameters
    ----------
    vec : array-like, shape (..., n * (n + 1) / 2)
        One vector, or a stack of vectors (leading axes are preserved).
    shape : tuple
        Shape ``(n, n)`` of one output matrix.

    Returns
    -------
    sym : ndarray, shape (..., n, n)
        Floating input keeps its dtype; integer input is promoted to a
        floating dtype.
    """
    vec = np.asarray(vec)
    # A Python float (not np.float64) so float32 input stays float32.
    sqrt2 = float(np.sqrt(2))
    # Builtin ``bool``: the ``np.bool`` alias is absent in NumPy 1.24-1.26.
    mask = np.tril(np.ones(shape, dtype=bool))
    # Allocate a floating array: an integer buffer could not hold the
    # 1/sqrt(2)-scaled off-diagonal entries.
    dtype = np.result_type(vec.dtype, 1.0)
    lower = np.zeros(vec.shape[:-1] + mask.shape, dtype)
    lower[..., mask] = vec
    on_diag = np.eye(mask.shape[-1], dtype=bool)
    lower = np.where(on_diag, lower, lower / sqrt2)
    # Mirror the strict lower triangle; swapaxes transposes each matrix of
    # a stack rather than reversing all axes like ``.T`` would.
    return lower + np.swapaxes(np.tril(lower, k=-1), -1, -2)