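"""Explain a ResNet-18 image classifier with SHAP's DeepExplainer.

DeepExplainer's hooks break on torchvision's stock ResNet because BasicBlock
uses in-place operations (`out += identity` and `nn.ReLU(inplace=True)`).
This script rebuilds ResNet-18 with out-of-place equivalents, loads the
pretrained ImageNet weights into it, and plots SHAP values for one class.
"""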
import shap
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import numpy as np

# BasicBlock without in-place operations. DeepExplainer's hooks record each
# module's inputs and outputs, and torchvision's in-place ReLU and residual
# `+=` mutate tensors those hooks still reference.
class CustomBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1, downsample=None, groups=1, base_width=64, dilation=1, norm_layer=None):
        # groups, base_width, and dilation mirror torchvision's BasicBlock
        # signature but are unused: resnet18 only ever passes the defaults.
        super(CustomBasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = norm_layer(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1, bias=False)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = F.relu(out, inplace=False)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        # Out-of-place addition: torchvision's BasicBlock does `out += identity`,
        # which mutates a tensor DeepExplainer's hooks still reference.
        out = out + identity
        out = F.relu(out, inplace=False)

        return out

# Custom ResNet using CustomBasicBlock
class CustomResNet(nn.Module):
    def __init__(self, block, layers, num_classes=1000):
        super(CustomResNet, self).__init__()
        self.inplanes = 64
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=False)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        norm_layer = nn.BatchNorm2d
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion, kernel_size=1, stride=stride, bias=False),
                norm_layer(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample, groups=1, base_width=64, dilation=1, norm_layer=norm_layer))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes, groups=1, base_width=64, dilation=1, norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def forward(self, x):
        # Every op below is already out-of-place (self.relu was built with
        # inplace=False), so no defensive .clone() calls are needed here.
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)

        return x

# Initialize the custom model with pre-trained weights
model = CustomResNet(CustomBasicBlock, [2, 2, 2, 2])
state_dict = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).state_dict()
model.load_state_dict(state_dict)
model.eval()
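
# Optional sanity check (a quick sketch; `reference` and `probe` are
# illustrative names): since the same weights were loaded and only in-place
# ops were replaced, the custom model should reproduce torchvision's outputs.
reference = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).eval()
with torch.no_grad():
    probe = torch.randn(1, 3, 224, 224)
    assert torch.allclose(model(probe), reference(probe), atol=1e-5)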

# Initialize the SHAP explainer with the custom model. A single random tensor
# serves as the background distribution for demonstration only; in practice,
# pass a batch of representative training images.
explainer = shap.DeepExplainer(model, torch.randn(1, 3, 224, 224))

# Generate SHAP values for an input image. check_additivity=False skips the
# additivity assertion, which may fail here since the functional ReLU calls
# in CustomBasicBlock bypass DeepExplainer's module hooks.
sample_image = torch.randn(1, 3, 224, 224)
shap_values = explainer.shap_values(sample_image, check_additivity=False)

# Convert SHAP values and the sample image to channels-last NumPy arrays,
# the (H, W, C) layout shap.image_plot expects. The indexing below assumes
# the list-per-class return format of SHAP's DeepExplainer (newer releases
# may instead return a single array with a trailing output dimension).
shap_values_class_0 = shap_values[0][0]  # first class, first sample: (C, H, W)
shap_values_class_0 = np.transpose(shap_values_class_0, (1, 2, 0))  # (C, H, W) -> (H, W, C)
sample_image_np = sample_image.squeeze().permute(1, 2, 0).detach().numpy()

# Clip the image into [0, 1] for display. The SHAP values are deliberately
# left unscaled: image_plot centers its colormap at zero, so min-max
# normalization would discard the sign of each attribution.
sample_image_np = np.clip(sample_image_np, 0, 1)

# image_plot expects a batch dimension on both arrays
shap_values_class_0 = shap_values_class_0[np.newaxis, ...]
sample_image_np = sample_image_np[np.newaxis, ...]

# Visualize SHAP values for the first class
shap.image_plot(shap_values_class_0, sample_image_np)
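
# To explain a real photograph rather than random noise, preprocess it with
# the standard ImageNet pipeline first (a sketch; "your_image.jpg" is a
# placeholder path, not a file shipped with this script):
#
#   from PIL import Image
#   from torchvision import transforms
#
#   preprocess = transforms.Compose([
#       transforms.Resize(256),
#       transforms.CenterCrop(224),
#       transforms.ToTensor(),
#       transforms.Normalize(mean=[0.485, 0.456, 0.406],
#                            std=[0.229, 0.224, 0.225]),
#   ])
#   image = preprocess(Image.open("your_image.jpg").convert("RGB")).unsqueeze(0)
#   shap_values = explainer.shap_values(image, check_additivity=False)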