-
Notifications
You must be signed in to change notification settings - Fork 8
/
Copy pathfeature_importance_heatmap.py
105 lines (90 loc) · 4.31 KB
/
feature_importance_heatmap.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/python
# -*- coding: utf-8 -*-
###############################################################################
#
###############################################################################
import argparse
from sklearn.externals import joblib
def load_classifiers(files):
    """ Return the objects stored in the given files, in input order.

    Every path in `files` is read with `joblib.load` and the loaded objects
    are collected in a list.

    NOTE: joblib.load relies on pickle, so only files coming from a trusted
    source should be passed here.
    """
    classifiers = []
    for path in files:
        classifiers.append(joblib.load(path))
    return classifiers
def plot_heatmap(feat_imp, infile, secorder, min_val, max_val):
    """ Plot a vector of feature importances as an SVG heatmap.

    `feat_imp` is read as one leading hit-score importance followed by
    consecutive runs of equal length, one run per DNA shape feature and one
    value per position: HelT, ProT, MGW, Roll and, when `secorder` is true,
    HelT2, ProT2, MGW2, Roll2. The hit-score importance is repeated over
    every position so that it fills a complete row of the heatmap.

    Parameters:
        feat_imp: sequence of importances laid out as described above.
        infile: basename of the output file ('.svg' is appended).
        secorder: whether the second order shape features are present.
        min_val: lower bound of the colour range (passed as `vmin`).
        max_val: upper bound of the colour range (passed as `vmax`).

    Raises:
        ValueError: if the length of `feat_imp` does not match the layout
            expected for the requested set of features.
    """
    shape_features = ['HelT', 'ProT', 'MGW', 'Roll']
    if secorder:
        shape_features += ['HelT2', 'ProT2', 'MGW2', 'Roll2']

    nb_shape_values = len(feat_imp) - 1
    # Floor division keeps len_motif an integer on both Python 2 and 3; a
    # true division would yield a float, which cannot repeat a list.
    len_motif = nb_shape_values // len(shape_features)
    if len_motif < 1 or nb_shape_values % len(shape_features):
        raise ValueError(
            'Expected 1 + {0} * motif_length feature importances, '
            'got {1}'.format(len(shape_features), len(feat_imp)))

    # The plotting stack is imported only once the input is known to be
    # valid, so that bad input fails fast.
    import pandas as pd
    import seaborn as sns
    import matplotlib.pyplot as plt

    features = ['hit_score'] + shape_features
    dico = {
        'Importance': [feat_imp[0]] * len_motif + list(feat_imp[1:]),
        # Each feature name is repeated once per position ...
        'Feature': [name for name in features for _ in range(len_motif)],
        # ... and the positions restart from 0 for every feature.
        'Position': list(range(len_motif)) * len(features),
    }
    heat_data = pd.DataFrame.from_dict(dico)
    heat_data = heat_data.pivot(index='Feature', columns='Position',
                                values='Importance')
    sns.heatmap(heat_data, linewidth=.5, vmin=min_val, vmax=max_val)
    plt.savefig('{0}.svg'.format(infile))
    # Clear the current figure so that the next heatmap starts from scratch.
    plt.clf()
def plot_average_heatmap(classifiers, output, secorder, min_val, max_val):
    """ Plot the heatmap of the importances averaged over the classifiers.

    The `feature_importances_` of every classifier become one column of a
    table; the row-wise mean of that table is handed to `plot_heatmap`
    together with `output`, `secorder`, `min_val` and `max_val`.
    """
    import pandas as pd
    per_classifier = {
        rank: pd.Series(model.feature_importances_)
        for rank, model in enumerate(classifiers)
    }
    averaged = pd.DataFrame(per_classifier).mean(axis=1)
    plot_heatmap(list(averaged), output, secorder, min_val, max_val)
def create_heatmap(argu):
    """ Plot one heatmap per classifier and, on request, their average.

    `argu` is the parsed command line: it provides `classif_files`,
    `secorder`, `min_val`, `max_val` and `output`. Each classifier is plotted
    under the name of the file it was loaded from. The averaged heatmap is
    only produced when `output` is set and more than one classifier was
    given.
    """
    import matplotlib
    # Select the 'svg' backend before plot_heatmap imports pyplot.
    matplotlib.use('svg')
    classif_paths = argu.classif_files
    loaded = load_classifiers(classif_paths)
    for path, clf in zip(classif_paths, loaded):
        plot_heatmap(clf.feature_importances_, path, argu.secorder,
                     argu.min_val, argu.max_val)
    if not argu.output or len(loaded) <= 1:
        return
    plot_average_heatmap(loaded, argu.output, argu.secorder, argu.min_val,
                         argu.max_val)
def arg_parsing():
    """ Build the command line parser and return the parsed arguments.

    The returned namespace carries `classif_files`, `output`, `secorder`,
    `min_val` and `max_val`, plus `func`, which is set to `create_heatmap`.
    """
    usage_text = '''
Plot the heatmap corresponding to the feature importance associated to the
classifier(s) provided.
'''
    cli = argparse.ArgumentParser(
        description=usage_text,
        formatter_class=argparse.RawDescriptionHelpFormatter)
    cli.add_argument(
        '-c', '--classif', dest='classif_files', nargs='+', required=True,
        help='Classifier(s) to be used (.pkl file)')
    cli.add_argument(
        '-a', '--average', dest='output', type=str, default=None,
        help=('Basename of the output averaged heatmap over multiple '
              'classifiers (.svg will be added)'))
    cli.add_argument(
        '-2', '--second', dest='secorder', action='store_true',
        help='Plot classifier using 2nd order DNA shape features.')
    cli.add_argument(
        '-m', '--min', dest='min_val', type=float, default=None,
        help='Minimal value for the heat map range')
    cli.add_argument(
        '-M', '--max', dest='max_val', type=float, default=None,
        help='Maximal value for the heat map range')
    # The entry point retrieves the function to run from the namespace.
    cli.set_defaults(func=create_heatmap)
    return cli.parse_args()
###############################################################################
# MAIN
###############################################################################
if __name__ == "__main__":
    # Parse the command line, then run the function stored under `func`
    # in the parsed namespace.
    cli_args = arg_parsing()
    cli_args.func(cli_args)