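"""U-Net with a ResNet-50 encoder for chest X-ray (CXR) segmentation.

`Unet.segment` runs the network on a cropped CXR and returns an expanded
crop box (top, left, bottom, right) together with the predicted mask.
"""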
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torchvision.models import resnet34, resnet50, resnet101, resnet152
from torchvision.models.resnet import conv3x3, BasicBlock, Bottleneck
from pathlib import Path

import numpy as np
import cv2
import skimage
from scipy import ndimage
import pretrainedmodels

from layers import SaveFeature
from constant import IMAGENET_MEAN, IMAGENET_STD

device = "cuda" if torch.cuda.is_available() else "cpu"
class UpBlock(nn.Module):
    """Decoder block: upsample `u` by 2x, concatenate the encoder skip `x`, fuse with a 3x3 conv."""
    expansion = 1

    def __init__(self, inplanes, planes, expansion=1):
        super().__init__()
        inplanes = inplanes * expansion
        planes = planes * expansion
        self.upconv = nn.ConvTranspose2d(inplanes, planes, 2, 2, 0)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        # conv1 expects `inplanes` channels: the skip plus the upsampled features
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes)

    def forward(self, u, x):
        up = self.relu(self.bn1(self.upconv(u)))
        out = torch.cat([x, up], dim=1)  # concatenate along the channel dimension
        out = self.relu(self.bn2(self.conv1(out)))
        return out

class UpLayer(nn.Module):
    """An UpBlock followed by (blocks - 1) residual blocks, mirroring a ResNet layer."""

    def __init__(self, block, inplanes, planes, blocks):
        super().__init__()
        self.up = UpBlock(inplanes, planes, block.expansion)
        layers = [block(planes * block.expansion, planes) for _ in range(1, blocks)]
        self.conv = nn.Sequential(*layers)

    def forward(self, u, x):
        x = self.up(u, x)
        x = self.conv(x)
        return x

class Unet(nn.Module):
    """U-Net built on a ResNet-50 encoder.

    Skip features are captured with SaveFeature forward hooks and fused back in
    by three UpLayers plus a final transposed convolution that produces a
    single-channel map at the 256x256 network input resolution.
    """
    tfm = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD)
    ])

    def __init__(self, trained=False, model_name=None):
        super().__init__()
        self.layers = [3, 4, 6]  # decoder block counts, mirroring ResNet-50's layer1-3
        self.block = Bottleneck
        if trained:
            assert model_name is not None
            self.load_model(model_name)
        else:
            self.load_pretrained()
    def cut_model(self, model, cut):
        return list(model.children())[:cut]

    def load_model(self, model_name):
        # build the architecture, then load fine-tuned weights stored next to this file
        resnet = resnet50(False)
        self.backbone = nn.Sequential(*self.cut_model(resnet, 8))
        self.init_head()
        model_path = Path(__file__).parent / model_name
        state_dict = torch.load(model_path, map_location=torch.device(device))
        self.load_state_dict(state_dict)

    def load_pretrained(self, use_torchvision=False):
        # ImageNet-pretrained encoder: torchvision's resnet50 or, by default,
        # the pretrainedmodels package
        if use_torchvision:
            resnet = resnet50(True)
        else:
            resnet = pretrainedmodels.__dict__['resnet50']()
        self.backbone = nn.Sequential(*self.cut_model(resnet, 8))
        self.init_head()
    def init_head(self):
        # hook the encoder stages reused as skip connections:
        # relu (64ch, /2), layer1 (256ch, /4), layer2 (512ch, /8), layer3 (1024ch, /16)
        self.sfs = [SaveFeature(self.backbone[i]) for i in [2, 4, 5, 6]]
        self.up_layer1 = UpLayer(self.block, 512, 256, self.layers[-1])
        self.up_layer2 = UpLayer(self.block, 256, 128, self.layers[-2])
        self.up_layer3 = UpLayer(self.block, 128, 64, self.layers[-3])
        self.map = conv3x3(64 * self.block.expansion, 64)  # 64e -> 64
        self.conv = conv3x3(128, 64)
        self.bn_conv = nn.BatchNorm2d(64)
        self.up_conv = nn.ConvTranspose2d(64, 1, 2, 2, 0)
        self.bn_up = nn.BatchNorm2d(1)
    def forward(self, x):
        x = F.relu(self.backbone(x))
        # decode, fusing the hooked encoder features from deepest to shallowest
        x = self.up_layer1(x, self.sfs[3].features)
        x = self.up_layer2(x, self.sfs[2].features)
        x = self.up_layer3(x, self.sfs[1].features)
        x = self.map(x)
        x = F.interpolate(x, scale_factor=2)
        x = torch.cat([self.sfs[0].features, x], dim=1)
        x = F.relu(self.bn_conv(self.conv(x)))
        x = F.relu(self.bn_up(self.up_conv(x)))
        return x

    def close(self):
        # detach the SaveFeature forward hooks
        for sf in self.sfs:
            sf.remove()
    def segment(self, image):
        """
        image: cropped CXR as an RGB PIL Image
        returns: ((top, left, bottom, right), mask), where the box is an expanded
        bounding box of the predicted foreground in original-image coordinates and
        mask is the morphologically opened prediction resized to the original size.
        """
        kernel = np.ones((10, 10), np.uint8)
        iw, ih = image.size
        image_tensor = self.tfm(image).unsqueeze(0).to(next(self.parameters()).device)
        with torch.no_grad():
            py = torch.sigmoid(self(image_tensor))
        py = (py[0].cpu() > 0.5).type(torch.FloatTensor)  # 1 x 256 x 256
        mask = py[0].numpy()
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)  # remove small speckles
        mask = cv2.resize(mask, (iw, ih))
        # bounding box of the foreground; find_objects needs an integer array
        slice_y, slice_x = ndimage.find_objects(mask.astype(int), 1)[0]
        h, w = slice_y.stop - slice_y.start, slice_x.stop - slice_x.start
        # expand the box so the region occupies 87.5% of it, clipped to the image
        nw, nh = int(w / .875), int(h / .875)
        dw, dh = (nw - w) // 2, (nh - h) // 2
        t = max(slice_y.start - dh, 0)
        l = max(slice_x.start - dw, 0)
        b = min(slice_y.stop + dh, ih)
        r = min(slice_x.stop + dw, iw)
        return (t, l, b, r), mask
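

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the model definition). It assumes a
# fine-tuned checkpoint named 'unet.h5' sits next to this file and uses a
# hypothetical input path 'cxr.png'.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from PIL import Image

    model = Unet(trained=True, model_name="unet.h5").to(device).eval()
    image = Image.open("cxr.png").convert("RGB")
    (t, l, b, r), mask = model.segment(image)
    print("crop box (top, left, bottom, right):", (t, l, b, r))
    print("mask shape:", mask.shape)
    model.close()  # detach the SaveFeature hooks when done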