Spaces:
Sleeping
Sleeping
File size: 1,236 Bytes
8e0b903 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import torch
import numpy as np
import cv2
from chexnet import ChexNet
from layers import SaveFeature
from constant import CLASS_NAMES
class HeatmapGenerator:
    """Turn the backbone features of a ChexNet prediction into a heatmap.

    Two mapping strategies are available, chosen by ``mode``:

    * ``'cam'`` -- weight the saved feature map by the head weights of the
      predicted class (see :meth:`cam`).
    * anything else (including the default ``None``) -- take the largest
      absolute activation across the leading axis of the saved feature map
      (see :meth:`default`).
    """

    def __init__(self, chexnet, mode=None):
        """Attach to ``chexnet`` and select the mapping strategy.

        Args:
            chexnet: Model wrapper exposing ``backbone``, ``head`` and
                ``predict(image)``.
            mode: ``'cam'`` selects :meth:`cam`; any other value selects
                :meth:`default`.
        """
        self.chexnet = chexnet
        # The mapping methods read ``self.sf.features`` afterwards.
        # NOTE(review): presumably SaveFeature registers a forward hook on the
        # backbone that stores its output -- confirm in layers.py.
        self.sf = SaveFeature(chexnet.backbone)
        # First parameter of the last child module of the head; ``cam`` indexes
        # it by class id. Presumably the final linear layer's weight matrix.
        last_layer = list(self.chexnet.head.children())[-1]
        self.weight = list(last_layer.parameters())[0]
        self.mapping = self.cam if mode == 'cam' else self.default

    def cam(self, pred_y):
        """Return the class-weighted map for class ``pred_y`` as a NumPy array.

        The first saved feature map (rank 3) has its leading axis moved last
        and is then contracted with ``self.weight[pred_y]``, leaving a 2-D map.
        """
        # .cpu() keeps this working when the model lives on an accelerator;
        # it is a no-op for tensors that are already on the CPU.
        features = self.sf.features[0].permute(1, 2, 0).detach().cpu().numpy()
        class_weight = self.weight[pred_y].detach().cpu().numpy()
        return features @ class_weight

    def default(self, pred_y):
        """Return the max absolute activation over the leading feature axis.

        ``pred_y`` is accepted only so both mappings share one call signature;
        this mapping does not depend on the predicted class. Like :meth:`cam`
        it uses the first saved feature map and returns a 2-D NumPy array.
        """
        strongest = torch.max(torch.abs(self.sf.features[0]), dim=0)[0]
        return strongest.detach().cpu().numpy()

    def generate(self, image):
        """Predict on ``image`` and return ``(heatmap, class_name)``.

        ``image.size`` is unpacked as ``(width, height)`` and the heatmap is
        resized to those dimensions.
        """
        prob = self.chexnet.predict(image)
        w, h = image.size
        return self.from_prob(prob, w, h)

    def from_prob(self, prob, w, h):
        """Build the heatmap for the most probable class in ``prob``.

        Args:
            prob: Class scores from the most recent ``predict`` call; the
                saved features must belong to that same call.
            w: Output width in pixels.
            h: Output height in pixels.

        Returns:
            ``(heatmap, class_name)`` where ``heatmap`` is scaled to [0, 1]
            and has been resized with ``cv2.resize(heatmap, (w, h))``.
        """
        pred_y = int(np.argmax(prob))
        heatmap = self._normalize(self.mapping(pred_y))
        heatmap = cv2.resize(heatmap, (w, h))
        return heatmap, CLASS_NAMES[pred_y]

    @staticmethod
    def _normalize(heatmap):
        """Shift ``heatmap`` to start at 0 and scale its peak to 1.

        A constant map has no range to scale, so it is returned as all zeros
        instead of being divided by zero (which would yield NaNs).
        """
        heatmap = heatmap - np.min(heatmap)
        peak = np.max(heatmap)
        if peak > 0:
            heatmap = heatmap / peak
        return heatmap
|