import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

class Flatten(nn.Module):
    """Flatten everything after the batch axis: (bs, C, H, W) -> (bs, C*H*W)."""

    def forward(self, x):
        return x.view(x.size(0), -1)
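
# Usage sketch (illustrative only): Flatten is typically placed between a
# pooling layer and a classifier head. Recent PyTorch versions ship an
# equivalent nn.Flatten().
#
#   x = torch.randn(4, 2048, 1, 1)
#   Flatten()(x).shape  # torch.Size([4, 2048])
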
class LSEPool2d(nn.Module):
    """Log-sum-exp (LSE) pooling over the spatial dimensions of a feature map."""

    def __init__(self, r=3):
        super().__init__()
        self.r = r

    def forward(self, x):
        # x: bs x C x H x W (e.g. bs x 2048 x 7 x 7 from a ResNet backbone)
        h, w = x.size(2), x.size(3)
        r = self.r
        x_max = F.adaptive_max_pool2d(x, 1)  # bs x C x 1 x 1
        # Subtract the per-channel max inside the exp for numerical stability.
        p = (1 / r) * torch.log((1 / (h * w)) * torch.exp(r * (x - x_max)).sum(3).sum(2))
        x_max = x_max.view(x.size(0), -1)  # bs x C
        return x_max + p
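
# Per channel, the layer computes the numerically stable form of
#   LSE_r(x) = (1/r) * log( (1/(h*w)) * sum_ij exp(r * x_ij) ),
# which interpolates between max pooling (r -> inf) and average pooling
# (r -> 0). Usage sketch (shapes here are illustrative assumptions):
#
#   pool = LSEPool2d(r=3)
#   feats = torch.randn(2, 2048, 7, 7)
#   pool(feats).shape  # torch.Size([2, 2048])
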
class WeightedBCEWithLogitsLoss(nn.Module):
    """BCE-with-logits loss with per-batch balancing of positive/negative labels."""

    def forward(self, input, target):
        w = self.get_weight(input, target)
        return F.binary_cross_entropy_with_logits(input, target, w, reduction='mean')

    def get_weight(self, input, target):
        y = target.detach().cpu().numpy()
        P = np.count_nonzero(y == 1)
        N = np.count_nonzero(y == 0)
        beta_p = (P + N) / (P + 1)  # +1: a batch may contain no disease labels
        beta_n = (P + N) / max(N, 1)  # guard against an all-positive batch
        w = np.empty(y.shape)
        w[y == 0] = beta_n
        w[y == 1] = beta_p
        # Match the logits' device and dtype instead of hard-coding .cuda(),
        # so the loss also runs on CPU.
        return torch.as_tensor(w, dtype=input.dtype, device=input.device)
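
# Usage sketch (illustrative only): the loss expects raw logits and a binary
# multi-label target of the same shape; the weights are recomputed per batch.
#
#   criterion = WeightedBCEWithLogitsLoss()
#   logits = torch.randn(8, 14)  # e.g. 14 disease labels per image
#   labels = torch.randint(0, 2, (8, 14)).float()
#   loss = criterion(logits, labels)
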
class SaveFeature:
    """Forward hook that stores a module's output, e.g. for class activation maps."""

    features = None

    def __init__(self, m):
        self.hook = m.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.features = output

    def remove(self):
        self.hook.remove()
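
# Usage sketch (illustrative only; the model and layer names are assumptions):
#
#   import torchvision
#   model = torchvision.models.resnet50(weights=None)
#   sf = SaveFeature(model.layer4)  # hook the last conv block
#   _ = model(torch.randn(1, 3, 224, 224))
#   fmap = sf.features  # (1, 2048, 7, 7) feature map
#   sf.remove()  # always detach the hook when done
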
# class FocalLoss(WeightedBCEWithLogitsLoss):
#     def __init__(self, theta=2):
#         super().__init__()
#         self.theta = theta
#
#     def forward(self, input, target):
#         # pt = target * input + (1 - target) * (1 - input)
#         # target *= (1 - pt) ** self.theta
#         w = self.get_weight(input, target)
#         return F.binary_cross_entropy_with_logits(input, target, w)
# class FocalLoss(nn.Module):
#     def __init__(self, gamma=0, alpha=None, size_average=True):
#         super().__init__()
#         self.gamma = gamma
#         self.alpha = alpha
#         if isinstance(alpha, (float, int)):
#             self.alpha = torch.Tensor([alpha, 1 - alpha])
#         if isinstance(alpha, list):
#             self.alpha = torch.Tensor(alpha)
#         self.size_average = size_average
#
#     def forward(self, input, target):
#         if input.dim() > 2:
#             input = input.view(input.size(0), input.size(1), -1)  # N,C,H,W -> N,C,H*W
#             input = input.transpose(1, 2)                         # N,C,H*W -> N,H*W,C
#             input = input.contiguous().view(-1, input.size(2))    # N,H*W,C -> N*H*W,C
#         target = target.view(-1, 1)
#         logpt = F.log_softmax(input, dim=1)
#         logpt = logpt.gather(1, target)
#         logpt = logpt.view(-1)
#         pt = logpt.detach().exp()
#         if self.alpha is not None:
#             if self.alpha.type() != input.type():
#                 self.alpha = self.alpha.type_as(input)
#             at = self.alpha.gather(0, target.view(-1))
#             logpt = logpt * at
#         loss = -1 * (1 - pt) ** self.gamma * logpt
#         if self.size_average:
#             return loss.mean()
#         return loss.sum()
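
if __name__ == "__main__":
    # Minimal smoke test (shapes are illustrative assumptions); runs on CPU
    # because the loss weights follow the logits' device.
    x = torch.randn(2, 2048, 7, 7)
    print(Flatten()(F.adaptive_avg_pool2d(x, 1)).shape)  # torch.Size([2, 2048])
    print(LSEPool2d(r=3)(x).shape)                       # torch.Size([2, 2048])
    logits = torch.randn(4, 14)
    labels = torch.randint(0, 2, (4, 14)).float()
    print(WeightedBCEWithLogitsLoss()(logits, labels).item())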