# Provenance (residue from the hosting site's file view): uploaded by
# datnguyentien204 in commit 8e0b903 ("Upload 338 files", verified);
# file size 2.14 kB.
import torch.nn as nn
import pretrainedmodels
from torchvision.models import densenet121
from layers import Flatten
import torch
import torchvision.transforms as transforms
from pathlib import Path
from constant import IMAGENET_MEAN, IMAGENET_STD
import os
import sys
# Append the sibling ``chestXray14`` directory (one level above this file's
# directory) to the module search path.
# NOTE(review): despite its name, ``yolov9`` holds the chestXray14 path.
# NOTE(review): this runs after the ``layers``/``constant`` imports above, so it
# cannot be what makes those importable -- confirm whether it should move up.
script_dir = os.path.dirname(os.path.abspath(__file__))
yolov9 = os.path.join(script_dir, '..', 'chestXray14')
sys.path.extend([yolov9])
class ChexNet(nn.Module):
    """DenseNet-121 feature extractor followed by a 14-output linear head.

    The head is ``AdaptiveAvgPool2d(1) -> Flatten -> Linear(1024, 14)``. Its
    outputs are treated as logits: :meth:`predict` applies a sigmoid to them.

    NOTE(review): an earlier comment claimed all parameters except the head are
    frozen, but nothing in this class sets ``requires_grad=False``. Freeze the
    backbone explicitly if that is required.
    """

    # Preprocessing used by ``predict``: resize to 224x224, convert to a float
    # tensor, then normalise with the ImageNet mean/std from ``constant``.
    tfm = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])

    def __init__(self, trained=False, model_name='20180525-222635'):
        """Build the network.

        Args:
            trained: if True, load a full checkpoint via :meth:`load_model`;
                otherwise use a pretrained backbone and a freshly initialised
                head via :meth:`load_pretrained`.
            model_name: forwarded to :meth:`load_model` (currently unused there).
        """
        super().__init__()
        if trained:
            self.load_model(model_name)
        else:
            self.load_pretrained()

    @staticmethod
    def _make_head():
        """Return the classification head shared by both loaders."""
        return nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            Flatten(),
            nn.Linear(1024, 14)
        )

    def load_model(self, model_name):
        """Build the architecture and load trained weights from disk.

        The checkpoint is read from the relative path ``chexnet.h5``, i.e. from
        the current working directory. It must be a state dict whose keys match
        this module's ``backbone``/``head`` parameters (strict loading).

        NOTE(review): ``model_name`` is accepted but not used; the checkpoint
        path is hard-coded. The parameter is kept for interface compatibility.
        """
        # No pretrained download: every backbone weight is overwritten below.
        self.backbone = densenet121(False).features
        self.head = self._make_head()
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts.
        # load_state_dict then copies the tensors into this module's parameters.
        # torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load('chexnet.h5', map_location='cpu')
        self.load_state_dict(state_dict)

    def load_pretrained(self, torch=False):
        """Build the architecture with a pretrained backbone and a new head.

        Args:
            torch: if True, take the DenseNet-121 backbone from ``torchvision``
                with pretrained weights; otherwise take it from
                ``pretrainedmodels``. (The parameter name shadows the ``torch``
                module, but only inside this method.)
        """
        if torch:
            self.backbone = densenet121(True).features
        else:
            self.backbone = pretrainedmodels.__dict__['densenet121']().features
        self.head = self._make_head()

    def forward(self, x):
        """Return head outputs (logits) for a batch of preprocessed images."""
        return self.head(self.backbone(x))

    def predict(self, image):
        """Return per-output sigmoid probabilities for a single image.

        Args:
            image: PIL image. Modes other than ``'RGB'`` (e.g. grayscale ``'L'``
                or ``'RGBA'``) are converted to RGB first.

        Returns:
            1-D numpy array of probabilities, one per head output.
        """
        # ToTensor yields 1 or 4 channels for non-RGB modes. That breaks the
        # per-channel Normalize and the backbone's 3-channel first convolution.
        if getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')
        image_tensor = self.tfm(image).unsqueeze(0)  # add batch dimension
        # Move the input to the same device as the model's parameters.
        image_tensor = image_tensor.to(next(self.parameters()).device)
        # Inference must use BatchNorm running statistics (eval mode) and must
        # not update them. Restore the caller's train/eval mode afterwards.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                py = torch.sigmoid(self(image_tensor))
        finally:
            self.train(was_training)
        return py.cpu().numpy()[0]