-
Notifications
You must be signed in to change notification settings - Fork 26
/
Copy pathcreate_explainer.py
105 lines (93 loc) · 3.76 KB
/
create_explainer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from explainer import backprop as bp
from explainer import deeplift as df
from explainer import gradcam as gc
from explainer import patterns as pt
from explainer import ebp
from explainer import real_time as rt
def get_explainer(model, name):
    """Construct the saliency explainer registered under ``name`` for ``model``.

    Args:
        model: The network to explain. For 'gradcam', 'excitation_backprop'
            and 'contrastive_excitation_backprop' the hooked layers are chosen
            from ``model.__class__.__name__``, which must be 'VGG', 'ResNet'
            or 'GoogleNet'.
        name: A key of the ``methods`` table below, or 'smooth_grad'.

    Returns:
        The constructed explainer instance.

    Raises:
        KeyError: If ``name`` is not a known explainer.
        ValueError: If ``name`` needs architecture-specific layer keys and the
            model's class name has no entry for that method.
    """
    methods = {
        'vanilla_grad': bp.VanillaGradExplainer,
        'grad_x_input': bp.GradxInputExplainer,
        'saliency': bp.SaliencyExplainer,
        'integrate_grad': bp.IntegrateGradExplainer,
        'deconv': bp.DeconvExplainer,
        'guided_backprop': bp.GuidedBackpropExplainer,
        'deeplift_rescale': df.DeepLIFTRescaleExplainer,
        'gradcam': gc.GradCAMExplainer,
        'pattern_net': pt.PatternNetExplainer,
        'pattern_lrp': pt.PatternLRPExplainer,
        'excitation_backprop': ebp.ExcitationBackpropExplainer,
        'contrastive_excitation_backprop': ebp.ContrastiveExcitationBackpropExplainer,
        'real_time_saliency': rt.RealTimeSaliencyExplainer,
    }

    # Constructor keyword arguments for methods that hook into named layers.
    # Outer key: method name; inner key: model.__class__.__name__.
    # Built per call so every explainer receives its own fresh lists.
    layer_kwargs = {
        'gradcam': {
            'VGG': {'target_layer_name_keys': ['features', '30']},  # pool5
            'GoogleNet': {'target_layer_name_keys': ['pool5'], 'use_inp': True},
            'ResNet': {'target_layer_name_keys': ['avgpool'], 'use_inp': True},
        },
        'excitation_backprop': {
            'VGG': {'output_layer_keys': ['features', '23']},  # vgg16 pool4
            # resnet50; originally labelled "res4a".
            # NOTE(review): torchvision's layer4 is usually res5 -- confirm.
            'ResNet': {'output_layer_keys': ['layer4', '1', 'conv1']},
            'GoogleNet': {'output_layer_keys': ['pool2']},
        },
        'contrastive_excitation_backprop': {
            'VGG': {
                'intermediate_layer_keys': ['features', '30'],  # pool5
                'output_layer_keys': ['features', '23'],  # pool4
                'final_linear_keys': ['classifier', '6'],  # fc8
            },
            'ResNet': {
                'intermediate_layer_keys': ['avgpool'],
                'output_layer_keys': ['layer4', '1', 'conv1'],
                'final_linear_keys': ['fc'],
            },
            'GoogleNet': {
                'intermediate_layer_keys': ['pool5'],
                'output_layer_keys': ['pool2'],
                'final_linear_keys': ['loss3.classifier'],
            },
        },
    }

    if name == 'smooth_grad':
        # SmoothGrad wraps a plain vanilla-gradient explainer.
        return bp.SmoothGradExplainer(methods['vanilla_grad'](model))

    if name not in methods:
        raise KeyError(
            'unknown explainer {!r}; expected one of {}'.format(
                name, sorted(list(methods) + ['smooth_grad'])))
    explainer_cls = methods[name]

    if 'pattern' in name:
        # PatternNet / PatternLRP load precomputed parameters and patterns.
        # The paths are hard-coded to VGG-16 ImageNet files, whatever ``model`` is.
        return explainer_cls(
            model,
            params_file='./weights/imagenet_224_vgg_16.npz',
            pattern_file='./weights/imagenet_224_vgg_16.patterns.A_only.npz',
        )

    if name == 'real_time_saliency':
        # Built from its own checkpoint; ``model`` is not passed.
        return explainer_cls('./weights/model-1.ckpt')

    if name in layer_kwargs:
        arch = model.__class__.__name__
        per_arch = layer_kwargs[name]
        if arch not in per_arch:
            # Previously this fell through and failed with UnboundLocalError.
            raise ValueError(
                '{!r} is not supported for model class {!r}; supported: {}'.format(
                    name, arch, sorted(per_arch)))
        return explainer_cls(model, **per_arch[arch])

    return explainer_cls(model)
def get_heatmap(saliency):
    """Turn a saliency tensor into a NumPy array of magnitudes.

    Singleton dimensions are squeezed away first. If two dimensions remain,
    their absolute values are returned directly. Otherwise the absolute values
    are reduced with a max over the leading dimension (presumably the channel
    axis -- confirm with callers). The result is moved to CPU before
    conversion.
    """
    magnitude = saliency.squeeze().abs()
    if magnitude.dim() != 2:
        magnitude, _ = magnitude.max(dim=0)
    return magnitude.cpu().numpy()