File size: 1,236 Bytes
8e0b903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import numpy as np
import cv2
from chexnet import ChexNet
from layers import SaveFeature
from constant import CLASS_NAMES


class HeatmapGenerator:
    """Produce a normalized 2-D heatmap and predicted class name for a ChexNet input.

    Two mappings are available, selected by ``mode``:

    * ``'cam'``: class-activation map. The saved backbone activations are
      combined with the head's last-layer weights for the predicted class.
    * anything else (the default): class-agnostic map. It is the maximum
      absolute activation across channels at each spatial location.
    """

    def __init__(self, chexnet, mode=None):
        """Attach to ``chexnet`` and choose the heatmap mapping.

        Args:
            chexnet: model exposing ``backbone``, ``head`` and ``predict``.
            mode: ``'cam'`` for class-activation maps; any other value
                selects the class-agnostic ``default`` mapping.
        """
        self.chexnet = chexnet
        # SaveFeature (project helper) presumably hooks the backbone and keeps
        # its latest activations in ``.features``; ``cam``/``default`` read
        # them after ``predict`` has run. TODO confirm against layers.py.
        self.sf = SaveFeature(chexnet.backbone)
        # First parameter of the head's last child module. For a Linear
        # classifier this is its weight, presumably (num_classes, channels).
        self.weight = list(list(self.chexnet.head.children())[-1].parameters())[0]
        self.mapping = self.cam if mode == 'cam' else self.default

    def cam(self, pred_y):
        """Return the class-activation map for class index ``pred_y``.

        The first sample's saved activations are moved channel-last and
        dotted with the weight row for ``pred_y``. The result is a 2-D numpy
        array over the feature map's spatial grid.
        """
        # features[0] is the first sample, laid out (channels, H, W) given the
        # permute; .cpu() makes the numpy conversion work for CUDA tensors too.
        features = self.sf.features[0].permute(1, 2, 0).detach().cpu().numpy()
        class_weight = self.weight[pred_y].detach().cpu().numpy()
        return features @ class_weight

    def default(self, pred_y):
        """Return a class-agnostic map: per-location max |activation| over channels.

        ``pred_y`` is accepted only so this method is interchangeable with
        ``cam`` via ``self.mapping``; it is ignored.
        """
        features = self.sf.features[0].detach()
        # Reduce over the channel axis (dim 0 of a (channels, H, W) sample).
        return torch.max(torch.abs(features), dim=0)[0].cpu().numpy()

    def generate(self, image):
        """Run the model on ``image`` and return ``(heatmap, class_name)``.

        The heatmap is resized to the image's size.
        """
        prob = self.chexnet.predict(image)
        # Unpacked as (width, height), PIL's ``Image.size`` convention;
        # assumes ``image`` is a PIL image. TODO confirm.
        w, h = image.size
        return self.from_prob(prob, w, h)

    def from_prob(self, prob, w, h):
        """Build the heatmap for the most probable class in ``prob``.

        Args:
            prob: class scores; the arg-max selects the class.
            w: output width in pixels.
            h: output height in pixels.

        Returns:
            ``(heatmap, class_name)``. ``heatmap`` is scaled to [0, 1] and
            resized to ``(w, h)``.
        """
        pred_y = int(np.argmax(prob))
        heatmap = self._normalize(self.mapping(pred_y))
        # cv2.resize takes dsize as (width, height).
        heatmap = cv2.resize(heatmap, (w, h))
        return heatmap, CLASS_NAMES[pred_y]

    @staticmethod
    def _normalize(heatmap):
        """Min-max scale ``heatmap`` to [0, 1]; a constant map becomes all zeros."""
        shifted = heatmap - np.min(heatmap)
        peak = np.max(shifted)
        # Without this guard a constant map would compute 0 / 0 -> NaN everywhere.
        if peak > 0:
            shifted = shifted / peak
        return shifted