# datnguyentien204's picture
# Upload 338 files
# 8e0b903 verified
import torch
import numpy as np
import cv2
from chexnet import ChexNet
from layers import SaveFeature
from constant import CLASS_NAMES
class HeatmapGenerator:
    """Build a normalised activation heatmap for a ChexNet prediction.

    Backbone activations are captured by a ``SaveFeature`` attached to
    ``chexnet.backbone``. Presumably it is populated during the forward pass
    inside ``chexnet.predict`` (TODO confirm). The heatmap is computed from
    the first sample of the captured batch.
    """

    def __init__(self, chexnet, mode=None):
        """
        Args:
            chexnet: model exposing ``backbone``, ``head`` and ``predict``.
            mode: ``'cam'`` weights the feature channels by the classifier
                weights of the predicted class (class activation map). Any
                other value uses the channel-wise max of absolute activations.
        """
        self.chexnet = chexnet
        self.sf = SaveFeature(chexnet.backbone)
        # First parameter of the head's last layer. Presumably the final
        # linear weight of shape (num_classes, C) -- TODO confirm.
        self.weight = list(list(self.chexnet.head.children())[-1].parameters())[0]
        self.mapping = self.cam if mode == 'cam' else self.default

    def cam(self, pred_y):
        """Return the class activation map for class index ``pred_y``.

        ``features[0]`` is treated as (C, H, W). It is permuted to (H, W, C)
        and projected onto the C-length weight row of ``pred_y``, giving an
        (H, W) NumPy array.
        """
        # .cpu() so this also works when the model runs on a GPU.
        features = self.sf.features[0].permute(1, 2, 0).detach().cpu().numpy()
        class_weights = self.weight[pred_y].detach().cpu().numpy()
        return features @ class_weights

    def default(self, pred_y=None):
        """Return the channel-wise max of absolute activations as (H, W).

        ``pred_y`` is accepted (and ignored) so this method is interchangeable
        with ``cam`` as ``self.mapping``.
        """
        features = self.sf.features[0].detach()
        return features.abs().max(dim=0)[0].cpu().numpy()

    @staticmethod
    def _normalize(heatmap):
        """Min-max scale ``heatmap`` to [0, 1]; a constant map becomes all zeros."""
        heatmap = heatmap - np.min(heatmap)
        peak = np.max(heatmap)
        # Guard against 0/0, which would fill the map with NaNs.
        if peak > 0:
            heatmap = heatmap / peak
        return heatmap

    def generate(self, image):
        """Predict on ``image`` and return ``(heatmap, class_name)``.

        ``image.size`` is read as ``(w, h)``, so a PIL-style image is presumably
        expected (TODO confirm).
        """
        prob = self.chexnet.predict(image)
        w, h = image.size
        return self.from_prob(prob, w, h)

    def from_prob(self, prob, w, h):
        """Build the heatmap for the top class in ``prob``, resized to ``(w, h)``.

        Returns:
            A tuple ``(heatmap, class_name)``. ``heatmap`` is a [0, 1] array
            of height ``h`` and width ``w``; ``class_name`` is
            ``CLASS_NAMES[argmax(prob)]``.
        """
        pred_y = int(np.argmax(prob))
        heatmap = self._normalize(self.mapping(pred_y))
        # cv2.resize takes dsize as (width, height).
        heatmap = cv2.resize(heatmap, (w, h))
        return heatmap, CLASS_NAMES[pred_y]