datnguyentien204's picture
Upload 338 files
8e0b903 verified
raw
history blame
1.24 kB
import torch
import numpy as np
import cv2
from chexnet import ChexNet
from layers import SaveFeature
from constant import CLASS_NAMES
class HeatmapGenerator:
    """Build a heatmap over an input image for a ChexNet model's top prediction.

    Two mapping strategies are supported, selected by ``mode``:

    * ``'cam'`` -- class activation map: the captured backbone feature maps of
      the first sample are projected onto the final head layer's weights for
      the predicted class.
    * anything else (including the default ``None``) -- the per-location
      maximum absolute activation across channels; the predicted class is
      ignored.

    The resulting map is min-max scaled to ``[0, 1]`` and resized to the
    requested image size.
    """

    def __init__(self, chexnet, mode=None):
        """Attach to *chexnet* and select the heatmap strategy.

        Args:
            chexnet: model exposing ``backbone``, ``head`` and ``predict``.
            mode: ``'cam'`` for class activation maps; any other value selects
                the activation-magnitude map (:meth:`default`).
        """
        self.chexnet = chexnet
        # SaveFeature presumably hooks the backbone so its latest output is
        # available as ``self.sf.features`` after a forward pass -- TODO
        # confirm against layers.SaveFeature.
        self.sf = SaveFeature(chexnet.backbone)
        # First parameter of the head's last child module; assumed to be the
        # final linear layer's weight of shape (num_classes, channels).
        self.weight = list(list(self.chexnet.head.children())[-1].parameters())[0]
        self.mapping = self.cam if mode == 'cam' else self.default

    def cam(self, pred_y):
        """Return the class activation map for class index *pred_y*.

        Only the first sample of the captured feature batch is used. Its
        feature maps are permuted so the channel axis is last, then contracted
        with the head weight row for *pred_y*, yielding a 2-D numpy array.
        """
        # .cpu() so this also works when the model runs on a GPU.
        features = self.sf.features[0].permute(1, 2, 0).detach().cpu().numpy()
        class_weights = self.weight[pred_y].detach().cpu().numpy()
        return features @ class_weights

    def default(self, pred_y=None):
        """Return the per-location maximum absolute activation across channels.

        *pred_y* is accepted only so this method can be used interchangeably
        with :meth:`cam` as ``self.mapping``; it is ignored. Only the first
        sample of the captured feature batch is used, and the result is a 2-D
        numpy array (same layout as :meth:`cam`).
        """
        features = self.sf.features[0].detach().cpu()
        # Tensor.max(dim=...) returns (values, indices); keep the values.
        return features.abs().max(dim=0)[0].numpy()

    @staticmethod
    def _normalize(heatmap):
        """Min-max scale *heatmap* to [0, 1]; a constant map becomes all zeros."""
        heatmap = heatmap - np.min(heatmap)
        peak = np.max(heatmap)
        # Guard against 0/0 -> NaN when every value is identical.
        if peak > 0:
            heatmap = heatmap / peak
        return heatmap

    def generate(self, image):
        """Predict on *image* and return ``(heatmap, class_name)``.

        *image* is presumably a PIL image, since ``image.size`` is unpacked as
        (width, height) -- confirm against callers.
        """
        prob = self.chexnet.predict(image)
        w, h = image.size
        return self.from_prob(prob, w, h)

    def from_prob(self, prob, w, h):
        """Build the heatmap for the most probable class in *prob*.

        Args:
            prob: class probabilities/scores; ``np.argmax`` is taken over the
                flattened values.
            w: output heatmap width in pixels.
            h: output heatmap height in pixels.

        Returns:
            ``(heatmap, class_name)``: ``heatmap`` is the normalised map resized
            to ``h`` rows by ``w`` columns, and ``class_name`` is
            ``CLASS_NAMES[pred_y]`` for the arg-max class.
        """
        pred_y = int(np.argmax(prob))
        heatmap = self._normalize(self.mapping(pred_y))
        # cv2.resize takes dsize as (width, height), not (rows, cols).
        heatmap = cv2.resize(heatmap, (w, h))
        return heatmap, CLASS_NAMES[pred_y]