# -*- coding: utf-8 -*-
"""
Created by:
Kai Sandvold Beckwith
Ellenberg group
EMBL Heidelberg
Color scheme for convenience:
tab:blue : #1f77b4
tab:orange : #ff7f0e
tab:green : #2ca02c
tab:red : #d62728
tab:purple : #9467bd
tab:brown : #8c564b
tab:pink : #e377c2
tab:gray : #7f7f7f
tab:olive : #bcbd22
tab:cyan : #17becf
"""
import os
import itertools
import logging
from math import cos, sin, radians
import re
from typing import *
import warnings
from numba import jit, njit
from numba.core.errors import NumbaPerformanceWarning
import numpy as np
#from numpy.core.fromnumeric import shape
import pandas as pd
#from plotly.offline import iplot
#import plotly.graph_objs as go
#import plotly.express as px
from scipy import interpolate
from scipy.spatial.distance import cdist, squareform, pdist
from scipy.spatial import ConvexHull
from scipy.cluster.hierarchy import linkage, dendrogram, fcluster
import matplotlib.pyplot as plt
import seaborn as sns
import tqdm
warnings.simplefilter('ignore', category=NumbaPerformanceWarning)
logger = logging.getLogger()
def pwd_calc(traces):
'''
Parameters
----------
traces : pd DataFrame with trace data.
Returns
-------
    pwds : Pairwise distance matrices for the traces as a 3D numpy array.
'''
points = points_from_traces_nan(traces, trace_ids = -1)
pwds = [cdist(p, p) for p in points]
pwds = np.stack(pwds)
return pwds
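# Illustrative usage sketch (not part of the original module): pwd_calc expects
# a DataFrame with traceId, z, y, x and QC columns; each trace yields one NxN
# distance matrix. The toy values and the _example_* name are hypothetical.
def _example_pwd_calc():
    toy = pd.DataFrame({
        'traceId': [0, 0, 0, 1, 1, 1],
        'z': [0.0, 1.0, 2.0, 0.0, 1.5, 3.0],
        'y': [0.0, 0.5, 1.0, 0.0, 0.5, 1.0],
        'x': [0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
        'QC': [1, 1, 1, 1, 1, 1],
    })
    return pwd_calc(toy).shape  # (2, 3, 3): two traces, three points each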
def trace_analysis(traces, pwds):
'''
    Calculates pairwise trace similarity based on:
- MSE of points after rigid alignment
- PCC of points after rigid alignment
- MSE of pairwise distance matrices
- PCC of pwds.
Parameters
----------
traces : pd DataFrame with trace data.
pwds : 3D np array from pwd_calc.
Returns
-------
output : pd DataFrame with pairwise similarity results, including indexes of traces.
'''
points = np.stack(points_from_traces(traces))
trace_idx = traces.traceId.unique()
res = trace_analysis_loop(points, pwds, trace_idx)
#pairwise_trace_idx = list(itertools.combinations(traces["traceId"].unique(),2))
#pairwise_pwd_idx = list(itertools.combinations(range(pwds.shape[0]),2))
#res = Parallel(n_jobs=-2)(delayed(single_trace_analysis)
# (traces, pwds, idx1, idx2, idx_p1, idx_p2) for
# ((idx1, idx2), (idx_p1, idx_p2)) in
# zip(pairwise_trace_idx,pairwise_pwd_idx))
columns=['idx1', 'idx2', 'aligned_mse', 'aligned_pcc', 'pwd_mse', 'pwd_pcc', 'pwd_invsq_mse', 'pwd_invsq_pcc']
output=pd.DataFrame(res,columns=columns)
output[['idx1', 'idx2']] = output[['idx1', 'idx2']].astype(int)
return output
@njit
def trace_analysis_loop(points, pwds, trace_idx):
res = []
idx = list(range(len(points)))
for i in idx[:-1]:
a = points[i]
d_1 = pwds[i]
for j in idx[i+1:]:
b = points[j]
d_2 = pwds[j]
out = list(single_trace_analysis(a,b,d_1,d_2))
res.append([trace_idx[i], trace_idx[j]] + out)
return res
def compare_trace_analysis(traces1, traces2, pwds1, pwds2):
points1 = points_from_traces(traces1)
points2 = points_from_traces(traces2)
trace_idx1 = list(traces1.traceId.unique())
trace_idx2 = list(traces2.traceId.unique())
idx1 = list(range(len(points1)))
idx2 = list(range(len(points2)))
res = []
for i in idx1:
a = points1[i]
d_1 = pwds1[i]
for j in idx2:
b = points2[j]
d_2 = pwds2[j]
out = list(single_trace_analysis(a,b,d_1,d_2))
res.append([trace_idx1[i], trace_idx2[j]] + out)
    columns=['idx1', 'idx2', 'aligned_mse', 'aligned_pcc', 'pwd_mse', 'pwd_pcc', 'pwd_invsq_mse', 'pwd_invsq_pcc']
output=pd.DataFrame(res,columns=columns)
return output
@njit
def single_trace_analysis(a,b,d_1,d_2):
'''
Perform pairwise analysis of two single traces according to MSE and PCC metrics
of their aligned points and their distance matrices.
Parameters
----------
    a, b : Nx4 (ZYX + QC) np arrays of the two traces.
    d_1, d_2 : NxN pairwise distance matrices of the two traces.
    Returns
    -------
    output : The similarity metrics of the registered trace pair.
'''
#Get points by their trace indices.
a, b = match_two_pointsets(a, b)
    if a.shape[0] < 3:
        # Fewer than 3 matched points: return sentinel values (large MSE, zero PCC).
        return 1000, 0, 1000, 0, 1000, 0
    #Center the points and rescale to avoid issues of large numbers for PCC calculation.
a = a-numba_mean_axis0(a)
b = b-numba_mean_axis0(b)
#Align the point sets
b_reg = rigid_transform_3D(b, a, prematch=True)
#Calculate distances between the point sets
aligned_mse = euclidean(a, b_reg)
aligned_pcc = 1-mat_corr_pcc(a, b_reg)
#Calculate distances between the point distance matrices
pwd_mse = euclidean(d_1,d_2)
#rescale data to avoid issues of large numbers in PCC calculation
pwd_pcc = 1-mat_corr_pcc(d_1/1000,d_2/1000)
pwd_invsq_mse = euclidean(np.triu(1/(d_1**2),2), np.triu(1/(d_2**2),2)) * 10e6
pwd_invsq_pcc = 1-mat_corr_pcc(np.triu(1/(d_1**2)/1000,2), np.triu(1/(d_2**2)/1000,2))
return aligned_mse, aligned_pcc, pwd_mse, pwd_pcc, pwd_invsq_mse, pwd_invsq_pcc
@njit
def align_two_traces(a,b):
a = a-numba_mean_axis0(a)
b = b-numba_mean_axis0(b)
b_reg = rigid_transform_3D(b, a, prematch=True)
return a, b_reg
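# Illustrative sketch (hypothetical example, not part of the original module):
# aligning a point set onto a rotated copy of itself should give a distance
# near zero after the rigid alignment performed by align_two_traces.
def _example_align_two_traces():
    rng = np.random.default_rng(0)
    a = rng.normal(size=(5, 3))
    theta = radians(30.0)
    rot = np.array([[cos(theta), -sin(theta), 0.0],
                    [sin(theta),  cos(theta), 0.0],
                    [0.0,         0.0,        1.0]])
    b = a @ rot.T  # rigidly rotated copy of a
    a_c, b_reg = align_two_traces(a, b)
    return euclidean(a_c, b_reg)  # ~0 after alignment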
def pwd_clustering(traces, metric='pcc', embedding='umap', clust_method='kmeans_emb', n_clusters = 3, diagonal = 2, extra_column = None, traces_rw=None):
from sklearn.manifold import MDS, TSNE, SpectralEmbedding
from sklearn.mixture import GaussianMixture, BayesianGaussianMixture
from sklearn import cluster
from sklearn.impute import SimpleImputer
import umap
if extra_column is not None:
extra_data = traces.groupby("traceId")[extra_column].max().to_numpy()
#print(extra_data)
points = points_from_traces_nan(traces, trace_ids = -1)
if metric == 'pcc':
features = np.stack([pdist(arr) for arr in points])
dist_func = pcc_dist
elif metric == 'contact_dist':
ind = np.triu_indices(points[0].shape[0], k=diagonal)
features = np.stack([np.ravel(cdist(arr, arr)[ind]) for arr in points])
dist_func = contact_dist
elif metric == 'pcc_sq':
ind = np.triu_indices(points[0].shape[0], k=diagonal)
features = np.stack([np.ravel(cdist(arr, arr)[ind]) for arr in points])
features = 1000/features**2
dist_func = pcc_dist
#features = features/np.max(features)
elif metric == 'pcc_rw':
ind = np.triu_indices(points[0].shape[0], k=diagonal)
features = np.stack([np.ravel(cdist(arr, arr)[ind]) for arr in points])
dist_rw = np.ravel(np.nanmean(pwd_calc(traces_rw), axis=0)[ind])
features = 1/(features**2/dist_rw)
dist_func = pcc_dist
elif metric == 'tda':
import gudhi as gd
import gudhi.representations
all_points = points_from_traces(traces)
res = []
for points in all_points:
qc = points[:,3] == 1
points = points[qc,:3]
acX = gd.AlphaComplex(points=points).create_simplex_tree()
dgmX = acX.persistence()
LS = gd.representations.Landscape(resolution=100)
L = LS.fit_transform([acX.persistence_intervals_in_dimension(1)])
res.append(L[0])
features = np.stack(res)
dist_func = 'euclidean'
features = SimpleImputer(missing_values=np.nan, strategy='constant', fill_value=0).fit_transform(features)
if embedding == 'mds':
emb = MDS(n_components=2, dissimilarity='euclidean')
pos = emb.fit_transform(features)
elif embedding == 'tsne':
emb = TSNE(metric = dist_func, learning_rate = 200, perplexity = 50, square_distances = True, init = 'random', n_jobs=-2)
pos = emb.fit_transform(features)
elif embedding == 'umap':
emb = umap.UMAP(metric = dist_func, n_neighbors=20, min_dist=0)
pos = emb.fit_transform(features)
elif embedding == 'umap_train':
def str_to_num(s):
return len(s)
extra_data = np.array(list(map(str_to_num, extra_data)))
emb = umap.UMAP(metric = dist_func, n_neighbors=30, min_dist=0)
pos = emb.fit_transform(features, y=extra_data)
elif embedding == 'SE':
emb = SpectralEmbedding(n_components=2)
pos = emb.fit_transform(features)
if clust_method == 'wards':
model = cluster.AgglomerativeClustering(n_clusters=n_clusters)
elif clust_method == 'affinity':
model = cluster.AffinityPropagation(damping=0.8, preference=-400)
elif clust_method == 'affinity_emb':
model = cluster.AffinityPropagation(damping=0.8, preference=-400)
features = pos
elif clust_method == 'kmeans':
model = cluster.KMeans(n_clusters=n_clusters)
elif clust_method == 'kmeans_emb':
model = cluster.KMeans(n_clusters=n_clusters)
features = pos
elif clust_method == 'dbscan_emb':
model = cluster.DBSCAN(eps=0.2, min_samples=10)
features = pos
elif clust_method == 'optics_emb':
model = cluster.OPTICS(metric='euclidean')
features = pos
elif clust_method == 'meanshift_emb':
model = cluster.MeanShift(bandwidth=0.8)
features = pos
elif clust_method == 'gmm_emb':
model = GaussianMixture(n_components=n_clusters, init_params='kmeans')
features = pos
elif clust_method == 'bayes_gmm_emb':
model = BayesianGaussianMixture(n_components=n_clusters)
features = pos
trace_ids = list(traces.traceId.unique())
clusters = model.fit_predict(features)
data = np.array([trace_ids, clusters, pos[:,0], pos[:,1]]).T
print(data.shape)
res = pd.DataFrame(data, columns=["traceId", 'cluster', 'pos_x', 'pos_y'])
res["traceId"] = res["traceId"].astype(int)
res['cluster'] = res['cluster'].astype(int)
if extra_column is not None:
res[extra_column] = extra_data
return res
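# Illustrative usage sketch (hypothetical, requires umap-learn and scikit-learn):
# embed traces with UMAP and cluster the embedding with k-means, as selected by
# the default arguments of pwd_clustering. `traces` is assumed to be a DataFrame
# with the traceId/z/y/x/QC schema used throughout this module.
def _example_pwd_clustering(traces):
    res = pwd_clustering(traces, metric='pcc', embedding='umap',
                         clust_method='kmeans_emb', n_clusters=3)
    return res  # columns: traceId, cluster, pos_x, pos_y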
def trace_clustering(pairs, metric='pwd_pcc', dendro_method='single', color_threshold=None):
'''
    Calculates clusters based on the similarity metrics from
    pairwise analysis of traces. Also plots a dendrogram of the clusters.
Parameters
----------
    pairs : Paired analysis DataFrame, output of trace_analysis
metric : One of the similarity metrics from trace_analysis:
- 'aligned_mse'
- 'aligned_pcc'
- 'pwd_mse'
- 'pwd_pcc'
    dendro_method : see scipy.cluster.hierarchy.linkage documentation
Returns
-------
    cluster_df : DataFrame with results of
    hierarchical clustering (see scipy.cluster.hierarchy.fcluster docs)
'''
from scipy.cluster.hierarchy import set_link_color_palette
cmap = ['tab:blue', 'tab:orange', 'tab:green', 'tab:red', 'tab:purple', 'tab:brown', 'tab:pink']#px.colors.qualitative.Plotly
#cmap=cmap[5:]
set_link_color_palette(cmap)
labels=list(np.unique(np.concatenate((pairs['idx1'],pairs['idx2']))))
Z=linkage(pairs[metric], method=dendro_method)
fig1 = plt.figure(figsize=(20,20))
dendro=dendrogram(Z, labels=labels, color_threshold=color_threshold, leaf_font_size=9)
if color_threshold is None:
color_threshold = 0.7*max(Z[:,2])
clusters=fcluster(Z, color_threshold, criterion='distance')
cluster_df=pd.DataFrame([labels, clusters]).T
cluster_df.columns=["traceId", 'dendro']
return cluster_df
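# Illustrative sketch of the intended pipeline (hypothetical wrapper): pairwise
# similarities from trace_analysis feed the dendrogram clustering above.
# 'average' linkage is an arbitrary example choice.
def _example_trace_clustering(traces):
    pwds = pwd_calc(traces)
    pairs = trace_analysis(traces, pwds)
    return trace_clustering(pairs, metric='pwd_pcc', dendro_method='average')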
def further_trace_clustering(pairs, cluster_df, metric, n_clusters=5):
    '''Perform a variety of sklearn clustering methods on the data.
Args:
        pairs (DataFrame): Pairwise similarity DataFrame from trace_analysis
        cluster_df (DataFrame): Existing clustering dataframe from dendrogram clustering
metric (str): Which metric to use as distance, typically aligned_pcc
n_clusters (int, optional): Number of predefined clusters. Defaults to 5.
Returns:
cluster_df (DataFrame): Updated clustering dataframe with additional clustering of all traces
        pos_mds, pos_tsne, pos_umap (ndarray): 2D positions of the traces from the MDS, t-SNE and UMAP embeddings.
'''
from sklearn.manifold import MDS, TSNE
from sklearn import cluster
import umap
#import hdbscan
distances = squareform(pairs[metric])
embedding_mds = MDS(n_components=2, dissimilarity='precomputed')
embedding_tsne = TSNE(metric = 'precomputed', learning_rate = 200, perplexity = 50, square_distances = True, init = 'random', n_jobs=-2)
emb_umap = umap.UMAP(metric = 'precomputed', n_neighbors=10, min_dist=0)
pos_mds = embedding_mds.fit_transform(distances)
pos_tsne = embedding_tsne.fit_transform(distances)
pos_umap = emb_umap.fit_transform(distances)
affinity_propagation = cluster.AffinityPropagation(damping=0.9, preference=-400)
aff = affinity_propagation.fit_predict(distances)
cluster_df['affinity'] = aff
aff_umap = affinity_propagation.fit_predict(pos_umap)
cluster_df['affinity_umap'] = aff_umap
spectral = cluster.SpectralClustering(n_clusters=n_clusters, eigen_solver='arpack', affinity="precomputed")
spec = spectral.fit_predict(1-distances)
cluster_df['spectral'] = spec
kmeans = cluster.KMeans(n_clusters=n_clusters)
km = kmeans.fit_predict(distances)
cluster_df['kmeans'] = km
km_umap = kmeans.fit_predict(pos_umap)
cluster_df['kmeans_umap'] = km_umap
#hdbscan_clusterer = hdbscan.HDBSCAN(metric='precomputed')
#hdb = hdbscan_clusterer.fit_predict(distances)
#cluster_df['hdb'] = hdb
#hdbscan_clusterer = hdbscan.HDBSCAN(metric='euclidean')
#hdb_umap = hdbscan_clusterer.fit_predict(pos_umap)
#cluster_df['hdb_umap'] = hdb_umap
#ward_model = cluster.AgglomerativeClustering(n_clusters=n_clusters)
#ward = ward_model.fit_predict(distances)
#cluster_df['ward'] = ward
return cluster_df, pos_mds, pos_tsne, pos_umap
def cluster_similarity(traces, cluster_df, method='cluster', metric='aligned_pcc'):
'''Function to compare similarity of traces within and between clusters.
#TODO: Hardcoded to aligned_pcc metric, fix this
Args:
        traces (DataFrame): Trace data
        cluster_df (DataFrame): Cluster assignment of traces
        method (str, optional): Which cluster assignment column to use. Defaults to 'cluster'.
        metric (str, optional): Which similarity metric to compare. Defaults to 'aligned_pcc'.
Returns:
(list): List with the mean values of the between and within cluster similarities.
'''
clust_ids = sorted(cluster_df[method].unique())
clust_pairs = list(itertools.combinations(clust_ids,2))
res_combo = []
for i, j in clust_pairs:
clust1 = sorted(cluster_df.query('{0} == {1}'.format(method, i)).traceId.unique())
clust2 = sorted(cluster_df.query('{0} == {1}'.format(method, j)).traceId.unique())
traces1 = traces[traces["traceId"].isin(clust1)]
traces2 = traces[traces["traceId"].isin(clust2)]
pwds1 = pwd_calc(traces1)
pwds2 = pwd_calc(traces2)
pairs_1_2 = compare_trace_analysis(traces1, traces2, pwds1, pwds2)
res_combo.append(pairs_1_2[metric].mean())
res_single = []
for i in clust_ids:
clust_i = sorted(cluster_df.query('{0} == {1}'.format(method, i)).traceId.unique())
traces_i = traces[traces["traceId"].isin(clust_i)]
pwds_i = pwd_calc(traces_i)
pairs_i = trace_analysis(traces_i, pwds_i)
res_single.append(pairs_i[metric].mean())
return [np.mean(res_combo), np.mean(res_single)]
def run_gpa_all_clusters(traces, cluster_df, metric='dendro', min_cluster = 1):
'''
    Performs GPA analysis on every cluster identified in trace_clustering()
    that has at least min_cluster members.
'''
#Find unique cluster IDs from clustering table.
cluster_ids=set(cluster_df[metric])
#Generate list of lists of all cluster members over min_cluster length.
all_cluster_members = []
for cluster_id in cluster_ids:
cluster_members = cluster_df[cluster_df[metric]==cluster_id]["traceId"].values
if len(cluster_members)>=min_cluster:
all_cluster_members.append(cluster_members)
print(f'Cluster ID {cluster_id}, members: {cluster_members}.')
    #Perform GPA analysis on each of the clusters separately.
all_mean_points = [general_procrustes_analysis(traces, cluster_members)[1]
for cluster_members in all_cluster_members]
#Choose random cluster mean as template for alignment.
template = all_mean_points[np.random.randint(0,len(all_cluster_members))]
#Align all cluster means to template.
aligned_mean_points = [rigid_transform_3D(mean_points, template) for
mean_points in all_mean_points]
    #Re-add the template to the output.
#aligned_mean_points += [template]
return aligned_mean_points
def general_procrustes_analysis(traces, trace_ids='all', crit=0.01, template_points = 6):
'''
    Generalized Procrustes analysis (GPA) is performed as described in e.g.
https://en.wikipedia.org/wiki/Generalized_Procrustes_analysis
Runs until change in procrustes distance is less than crit.
Returns all the aligned traces, the mean trace and the std of all the aligned traces.
'''
if isinstance(trace_ids, str):
trace_ids = list(traces["traceId"].unique())
elif isinstance(trace_ids, list):
pass
else:
trace_ids=list(trace_ids.astype(int))
# Make list of all points of selected traces
all_points = points_from_traces(traces, trace_ids)
# Select a random template for initial loop
#np.random.seed(1)
while True:
t_idx = np.random.randint(0,len(all_points))
template = all_points[t_idx]
if np.sum(template[:,3]) > template_points:
#print('Template with more than 6 points found.')
break
template = center_points_qc(template)
#The initial distance before alignment.
prev_dist = np.sum([procrustes_dist(template, points) for
points in all_points])
#print('Initial distance: ', prev_dist)
#print('Number of traces: ', len(all_points))
#Run the first alignment step:
all_points, points_mean, dist = general_procrustes_loop(all_points, template)
#Run the remaining alignment steps until crit is reached:
n_cycles = 0
while np.abs(prev_dist-dist) > crit:
prev_dist = dist
all_points, points_mean, dist = general_procrustes_loop(all_points, points_mean)
n_cycles += 1
#print(f'GPA converged after {n_cycles} cycles with distance {dist}.')
#Calculate standard deviation of all points:
points_std = np.nanstd(np.stack(all_points), axis = 0)
return all_points, points_mean, points_std
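# Illustrative usage sketch (hypothetical wrapper): run GPA on all traces and
# return the mean shape and per-point standard deviation of the aligned traces.
def _example_gpa(traces):
    aligned, mean_points, std_points = general_procrustes_analysis(traces, trace_ids='all')
    return mean_points, std_points  # Nx4 mean (ZYX + QC) and per-point std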
def piecewise_gpa(traces, trace_ids='all', crit=0.01, segment_length = 5, overlap = 3):
'''
    Piecewise GPA: Generalized Procrustes analysis (see general_procrustes_analysis)
    is performed on overlapping segments of the traces, which are then stitched
    together by rigidly aligning consecutive segments on their overlap.
    Returns the stitched full-length mean trace.
'''
if trace_ids == 'all':
trace_ids = list(traces["traceId"].unique())
else:
trace_ids=list(trace_ids.astype(int))
# Make list of all points of selected traces
hybs = np.sort(traces.query('QC == 1').hyb.unique())
segments = [hybs[i*(segment_length-overlap):i*(segment_length-overlap)+segment_length] for i in range(0,len(hybs)//(segment_length-overlap))]
segments = [s for s in segments if len(s) == segment_length]
print(segments)
aligned_segments = []
for seg in tqdm.tqdm(segments):
a = seg[0]
b = seg[-1]
traces_seg = traces.query('hyb >= @a & hyb <= @b')
aligned_segments.append(general_procrustes_analysis(traces_seg, trace_ids='all', crit=0.01, template_points = segment_length-1)[1])
full_trace = np.zeros((len(hybs), 4))
full_trace[0:segment_length, :] = aligned_segments[0]
for i in np.arange(0,len(aligned_segments)-1):
print(i)
prev_seg = full_trace[i*(segment_length-overlap):i*(segment_length-overlap) + segment_length].copy()
prev_seg[0:overlap-1,3] = 0
next_seg = aligned_segments[i+1].copy()
next_seg[-overlap+1:,3] = 0
reg_segment = rigid_transform_3D(next_seg,prev_seg,prematch=False)
new_avg = np.mean([prev_seg[-overlap:,:],reg_segment[:overlap,:]], axis=0)
next_seg[:overlap,:] = new_avg
next_seg[:,3] = 1
full_trace[(i+1)*(segment_length-overlap):(i+1)*(segment_length-overlap)+segment_length, :] = next_seg
return full_trace
def general_procrustes_loop(all_points, template):
'''
A single cycle in the general procrustes analysis.
Returns the points in all_points aligned to template,
the mean points and the procrustes distance to the mean.
'''
# Align all point sets to mean template
all_points_aligned = [rigid_transform_3D(offset, template) for
offset in all_points]
    all_points_aligned = [points for points in all_points_aligned if points.shape[0] > 3]  # Keep only traces with more than 3 matched points.
#Set values that do not pass QC to nan in new list.
all_points_aligned_qc=[]
for points in all_points_aligned:
points[points[:,3] == 0, 0:3]=np.nan
all_points_aligned_qc.append(points)
    #Calculate mean points from the QC list, ignoring NaNs:
#points_mean = numba_trimmean_axis0(np.stack(all_points_aligned_qc), proportiontocut=0.1)
points_mean = np.nanmean(np.stack(all_points_aligned_qc), axis=0)
#points_mean = np.nanmedian(np.stack(all_points_aligned_qc), axis=0)
#The "QC" for the mean is 1 if at least one element has a QC=1
points_mean[:,3]=np.ceil(points_mean[:,3])
# Calculate distance to mean:
dist = np.sum([procrustes_dist(points_mean, points) for
points in all_points_aligned])
return all_points_aligned, points_mean, dist
@njit(error_model="numpy")
def procrustes_dist(a, b):
'''
Procrustes distance (identical to RMSD) between two point sets.
Matches them before calculation.
'''
#print(a,b)
a, b = match_two_pointsets(a, b)
#print(a,b)
if a.shape[0] == 0:
return 1e6
else:
dist = np.sqrt(np.mean((a-b)**2))
return dist
@njit
def procrustes_dist_corr(a, b):
'''
Correlation (PCC) between two point sets.
Matches them before calculation.
'''
a, b = match_two_pointsets(a, b)
dist = 1-mat_corr_pcc(a,b)
return dist
@njit
def numba_mean_axis0(arr):
'''Helper function due to lack of numba support for axis arguments.
'''
return np.array([np.mean(arr[:,i]) for i in range(arr.shape[1])])
@njit
def numba_trimmean_axis0(arr, proportiontocut=0.1):
'''Helper function for a trimmed mean due to lack of numba support for a trimmed mean function.
Args:
        arr (ndarray): 3D array to reduce along axis 0.
        proportiontocut (float, optional): Fraction of values to cut from each tail. Defaults to 0.1.
    Returns:
        ndarray: NxD array of trimmed means along axis 0.
'''
N = arr.shape[1]
D = arr.shape[2]
res = []
for i in range(N):
for j in range(D):
a = arr[:,i,j]
a = np.sort(a[~np.isnan(a)])
            low = int(np.round(a.size*proportiontocut))  # cast to int for slicing
high = a.size - low
res.append(np.mean(a[low:high]))
res = np.array(res).reshape(N,D)
return res
@njit
def rigid_transform_3D(A_orig, B_orig, prematch = False):
'''
Calculates rigid transformation of two 3d points sets based on:
Least-squares fitting of two 3-D point sets. IEEE T Pattern Anal 1987
DOI: 10.1109/TPAMI.1987.4767965
    Finds the optimal (least-squares) solution of B = RA + t, i.e. the mapping of A onto B.
Modified from http://nghiaho.com/?page_id=671
Only uses points present in both traces for alignment.
Parameters
----------
    A_orig, B_orig : Nx4 (ZYX + QC) numpy ndarrays (Nx3 if prematch=True).
Returns
-------
Coordinates of registered and transformed A
'''
#Ensure matching points
if not prematch:
A, B = match_two_pointsets(A_orig, B_orig)
else:
A, B = A_orig, B_orig
if A.shape[0] == 0: #No matching points.
return A
# Subtract mean
# Workaround for lack of "axis" argument support in numba:
Ac = numba_mean_axis0(A)
Bc = numba_mean_axis0(B)
#Ac = np.mean(A, axis=0)
#Bc = np.mean(B, axis=0)
    Am = A - Ac
    Bm = B - Bc
# Calculate covariance matrix
H = Am.T @ Bm
# Find optimal rotation by SVD of the covariance matrix.
U, S, Vt = np.linalg.svd(H)
R = Vt.T @ U.T
# Handle case if the rotation matrix is reflected.
if np.linalg.det(R) < 0:
#print("det(R) < R, reflection detected!, correcting for it ...\n");
Vt[2,:] *= -1
R = Vt.T @ U.T
# calculate translation.
t = Bc - R @ Ac
#Transform the original vector with QC values
A_reg = np.copy(A_orig)
A_reg[:,:3] = (R @ A_orig[:,:3].T).T+t
return A_reg
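# Illustrative check (hypothetical, not part of the original module): applying a
# known rotation and translation to a point set and registering the original
# onto it should recover the transformed coordinates. The QC column is set to 1
# so all points are used for matching.
def _example_rigid_transform():
    rng = np.random.default_rng(1)
    pts = rng.normal(size=(8, 3))
    theta = radians(45.0)
    rot = np.array([[1.0, 0.0, 0.0],
                    [0.0, cos(theta), -sin(theta)],
                    [0.0, sin(theta),  cos(theta)]])
    A = np.column_stack((pts, np.ones(8)))
    B = A.copy()
    B[:, :3] = pts @ rot.T + np.array([10.0, -5.0, 2.0])
    A_reg = rigid_transform_3D(A, B)  # maps A onto B
    return np.allclose(A_reg[:, :3], B[:, :3], atol=1e-8)  # True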
@njit
def match_two_pointsets(A, B):
'''
Matches two point sets by their QC value to only return points
passing QC in both sets.
Parameters
----------
    A, B : Nx4 (ZYX + QC) numpy ndarrays.
    Returns
    -------
    A_match, B_match : Nx3 (ZYX) numpy ndarrays.
'''
match_idx = A[:,3] * B[:,3] != 0
A_match = A[match_idx,0:3]
B_match = B[match_idx,0:3]
return A_match, B_match
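# Tiny illustrative example (hypothetical values): only rows where both QC
# flags are nonzero survive the matching.
# _A = np.array([[0., 0., 0., 1.], [1., 1., 1., 0.]])
# _B = np.array([[0., 1., 0., 1.], [2., 2., 2., 1.]])
# match_two_pointsets(_A, _B)  # -> (array([[0., 0., 0.]]), array([[0., 1., 0.]]))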
def points_from_traces(traces, trace_ids=-1):
'''
Helper function to extract point coordinates from trace dataframe.
Parameters
----------
traces : pd DataFrame with trace data.
trace_ids: single or multiple trace_ids to extract
Returns
-------
points_qc : list of Nx4 np array with trace coordinates and QC value.
'''
arr = traces[["traceId", "z", "y", "x", "QC"]].to_numpy()
if trace_ids == -1:
trace_ids, idx = np.unique(arr[:,0], return_index=True)
        # Drop the traceId column so each array is Nx4 (z, y, x, QC), matching the explicit-ids branch.
        return np.split(arr[:,1:], idx[1:], axis=0)
elif not isinstance(trace_ids, (list, tuple)):
trace_ids = [trace_ids]
return [arr[arr[:,0] == i][:,1:] for i in trace_ids]
def points_from_traces_qc_filt(traces, trace_ids=-1):
'''
Helper function to extract point coordinates from trace dataframe.
    Points not passing QC during tracing are excluded from the output.
Parameters
----------
trace_df : pd DataFrame with trace data.
Returns
-------
points : Nx3 np array with trace coordinates.
'''
arr = traces[["traceId", "z", "y", "x", "QC"]].to_numpy()
qc_idx = arr[:,4] == 1
arr = arr[qc_idx,0:4]
if trace_ids == -1:
trace_ids, idx = np.unique(arr[:,0], return_index=True)
        # Drop the traceId column so each array is Nx3 (z, y, x), matching the explicit-ids branch.
        return np.split(arr[:,1:4], idx[1:], axis=0)
else:
if not isinstance(trace_ids, (list, tuple)):
trace_ids = [trace_ids]
        return [arr[arr[:,0] == tid][:,1:4] for tid in trace_ids]
def points_from_traces_nan(traces, trace_ids=-1):
'''
Helper function to extract point coordinates from trace dataframe.
All points are returned, but points not passing QC during tracing
are returned as NaN.
Parameters
----------
trace_df : pd DataFrame with trace data.
Returns
-------
points : Nx3 np array with trace coordinates, NaN row returned if point did not pass QC.
'''
arr = traces[["traceId", "z", "y", "x", "QC"]].to_numpy()
qc_idx = arr[:,4] == 1
arr[~qc_idx,1:4] = np.nan
if trace_ids == -1:
trace_ids, idx = np.unique(arr[:,0], return_index=True)
return np.split(arr[:,1:4], idx[1:], axis=0)
elif not isinstance(trace_ids, (list, tuple)):
trace_ids = [trace_ids]
    return [arr[arr[:,0] == tid][:,1:4] for tid in trace_ids]
def center_points_qc(a):
qc = a[:,3] # QC of points
ac = a[:,:3]-np.mean(a[qc>0,:3], axis=0) #Subtract mean of QC==1 points
ac = ac + np.abs(np.min(ac[qc>0,:3], axis=0)) #Shift so all values positive
qc = qc[:,np.newaxis] #Reshape QC to reappend
return np.append(ac,qc,axis=1)
def spline_interp(points, n_points=100):
    '''
    Performs quadratic (k=2) B-spline interpolation on point coordinates.
    Parameters
    ----------
    points : ndim-length list of coordinate arrays.
    n_points : number of interpolated points to return.
    Returns
    -------
    fine : list of ndim arrays, each with n_points interpolated coordinates.
    '''
tck, u = interpolate.splprep(points, s=0, k=2)
#knots = interpolate.splev(tck[0], tck)
u_fine = np.linspace(0,1,num=n_points)
fine = interpolate.splev(u_fine, tck)
return fine
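# Illustrative usage sketch (hypothetical values): interpolate a short 3D trace
# into 100 evenly spaced points along the spline parameter.
def _example_spline_interp():
    z = np.array([0.0, 1.0, 2.0, 3.0])
    y = np.array([0.0, 1.0, 0.0, 1.0])
    x = np.array([0.0, 0.5, 1.0, 1.5])
    fine = spline_interp([z, y, x], n_points=100)
    return np.column_stack(fine).shape  # (100, 3)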
@jit(nopython=True)
def euclidean(a, b):
return np.sqrt(np.nansum((a-b)**2))
@jit(nopython=True)
def mat_corr_rmse(a, b):
'''
    Calculate the root-mean-square error of two matrices, ignoring NaNs.
'''
rmse = np.sqrt(np.nanmean((a-b)**2))
return rmse
@jit(nopython=True)
def mat_corr_pcc(a,b):
'''
    Calculate the Pearson correlation coefficient of two matrices, ignoring NaNs.
'''
a_m=np.nanmean(a)
b_m=np.nanmean(b)
pcc_num=np.nansum((a-a_m)*(b-b_m))
pcc_denom=np.sqrt(np.nansum((a-a_m)**2))*np.sqrt(np.nansum((b-b_m)**2))
pcc=np.divide(pcc_num,pcc_denom)
return pcc
@jit(nopython=True)
def pcc_dist(a,b):
'''
    Calculate a PCC-based distance, sqrt(1 - PCC), of two flattened distance matrices, using only entries that are positive in both.
'''
ind = (a>0) & (b>0)
a = a[ind]
b = b[ind]
a_m=np.nanmean(a)
b_m=np.nanmean(b)
pcc_num=np.nansum((a-a_m)*(b-b_m))
pcc_denom=np.sqrt(np.nansum((a-a_m)**2))*np.sqrt(np.nansum((b-b_m)**2))
pcc=np.divide(pcc_num,pcc_denom)
return np.sqrt(1-pcc)
@jit(nopython=True)
def pcc_match(a,b):
'''
    Calculate the Pearson correlation coefficient of two flattened distance matrices, using only entries that are positive in both.
'''
ind = (a>0) & (b>0)
a = a[ind]
b = b[ind]
a_m=np.nanmean(a)
b_m=np.nanmean(b)
pcc_num=np.nansum((a-a_m)*(b-b_m))
pcc_denom=np.sqrt(np.nansum((a-a_m)**2))*np.sqrt(np.nansum((b-b_m)**2))
pcc=np.divide(pcc_num,pcc_denom)
return pcc
@jit(nopython=True)
def contact_dist(a,b):
'''
    Calculate a contact-based dissimilarity: one minus the fraction of entries below the 120 nm contact cutoff in both flattened distance matrices.
'''
ind = (a>0) & (b>0)
a = a[ind]
b = b[ind]
score = np.sum((a < 120) & (b < 120))
return 1-score/ind.size
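# Illustrative example (hypothetical values) of the distance helpers on two
# flattened PWD vectors; the 120 nm contact cutoff is hard-coded in
# contact_dist above.
def _example_pwd_metrics():
    d1 = np.array([100.0, 300.0, 50.0, 400.0])
    d2 = np.array([110.0, 280.0, 60.0, 500.0])
    return pcc_dist(d1, d2), pcc_match(d1, d2), contact_dist(d1, d2)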
@njit
def radius_of_gyration(point_set):
'''
    Calculate the radius of gyration: R = sqrt(1/N * sum_k (r_k - r_mean)^2) over the k points in the structure.
Source: https://en.wikipedia.org/wiki/Radius_of_gyration
'''
#Only include points passing QC:
#qc_idx = point_set[:,3] != 0
#point_set_qc = point_set[qc_idx, 0:3]
points_mean=numba_mean_axis0(point_set)
rog = np.sqrt(1/point_set.shape[0] * np.sum((point_set - points_mean)**2))
return rog
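# Illustrative check (hypothetical): for the eight corners of the unit cube the
# radius of gyration equals sqrt(3)/2 ~ 0.866.
def _example_rog():
    corners = np.array(list(itertools.product([0.0, 1.0], repeat=3)))
    return radius_of_gyration(corners)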
def elongation(point_set):
'''
    Elongation is here defined as one minus the ratio of the secondary to the
    primary eigenvalue of the point set covariance (0 = isotropic, -> 1 = elongated).
'''
#Only include points passing QC:
#qc_idx = point_set[:,3] != 0
#point_set_qc = point_set[qc_idx, 0:3]
#Center points to 0-mean
points_centered = point_set - np.mean(point_set, axis=0)
n, m = points_centered.shape
#Compute covariance matrix
cov = np.dot(points_centered.T, points_centered) / (n-1)
#Eigenvector decomposition of covariance matrix
eigen_vals, eigen_vecs = np.linalg.eig(cov)
    #Elongation is one minus the ratio of the secondary to the primary eigenvalue
eigen_vals = np.sort(eigen_vals)[::-1]
#print('Eigenvalues are ', eigen_vals)
elongation = 1-(eigen_vals[1]/eigen_vals[0])
return elongation
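# Illustrative example (hypothetical values): a nearly linear point set gives an
# elongation close to 1, an isotropic cloud a value close to 0.
def _example_elongation():
    rng = np.random.default_rng(2)
    line = np.column_stack((np.linspace(0.0, 10.0, 20),
                            0.01 * rng.normal(size=20),
                            0.01 * rng.normal(size=20)))
    return elongation(line)  # close to 1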
def contour_length(point_set):
    '''Calculate the contour length of a trace.
    Args:
        point_set (ndarray): coordinate point array (optionally with QC column)
    Returns:
        float: the contour length (sum of consecutive point-to-point distances)
    '''
#Only include points passing QC:
#qc_idx = point_set[:,3] != 0
#point_set_qc = point_set[qc_idx, 0:3]
dist = 0
for i in range(point_set.shape[0]-1):
dist += np.linalg.norm(point_set[i+1]-point_set[i])
return dist
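# Illustrative example (hypothetical values): three collinear points spaced one
# unit apart give a contour length of 2.
# contour_length(np.array([[0., 0., 0.], [0., 0., 1.], [0., 0., 2.]]))  # -> 2.0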
def trace_metrics(traces, use_interp = False, diagonal = 0, contact_cutoff = 150, only_small_loops = False):
points_nan = points_from_traces_nan(traces)
points_qc = points_from_traces_qc_filt(traces)
#ind_u = np.triu_indices(points_nan[0].shape[0], k=2)
ind_l = np.tril_indices(points_nan[0].shape[0], k=0)
trace_ids = traces.traceId.unique()
metrics = []
loop_metrics = []
for i in tqdm.tqdm(range(len(points_nan))):
if use_interp:
point_set = np.column_stack(spline_interp([points_qc[i][:,0],points_qc[i][:,1],points_qc[i][:,2]]))
else:
point_set = points_nan[i]
dists = cdist(point_set, point_set)
ind_l = np.tril_indices(point_set.shape[0], k=diagonal)
dists_diag = dists.copy()
dists_diag[ind_l] = np.nan
contacts = dists_diag < contact_cutoff
contact_coords = np.argwhere(contacts)
#print(contact_coords)
stacked_loops = []
for c in contact_coords:
first_anchor = contact_coords[np.argwhere(contact_coords[:,0] == c[0]), :]
for a in first_anchor:
try:
second_anchor = np.all(contact_coords == [a[0,1],c[1]], axis=1)