Spaces:
Running
on
T4
Running
on
T4
File size: 4,912 Bytes
561c629 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
# -*- coding: utf-8 -*-
import os
import torch
from torch import nn as nn
import torch.nn.functional as F
class PixelLoss(nn.Module):
    """Mean absolute error (L1) between a generated and a reference tensor.

    ``forward`` accepts ``batch_idx`` so its call signature matches the other
    loss classes in this file; the value is not used.
    """

    def __init__(self) -> None:
        super().__init__()
        # nn.L1Loss owns no parameters or buffers, so it needs no device
        # placement. Its default reduction averages over every element.
        self.criterion = nn.L1Loss()

    def forward(self, gen_hr, org_hr, batch_idx):
        """Return the mean L1 distance between ``gen_hr`` and ``org_hr``.

        Args:
            gen_hr (torch.Tensor): generated tensor.
            org_hr (torch.Tensor): reference tensor.
            batch_idx (int): current iteration; unused.

        Returns:
            torch.Tensor: scalar loss.
        """
        return self.criterion(gen_hr, org_hr)
class L1_Charbonnier_loss(nn.Module):
    """Charbonnier penalty: a smooth, everywhere-differentiable L1 variant."""

    def __init__(self):
        super(L1_Charbonnier_loss, self).__init__()
        # Constant added under the square root in forward(); it is used
        # as-is there (not squared again).
        self.eps = 1e-6

    def forward(self, X, Y, batch_idx):
        """Return mean(sqrt((X - Y)^2 + eps)) over all elements.

        ``batch_idx`` is accepted for signature compatibility and is unused.
        """
        residual = X - Y
        penalty = (residual * residual + self.eps).sqrt()
        return penalty.mean()
"""
Created on Thu Dec 3 00:28:15 2020
@author: Yunpeng Li, Tianjin University
"""
class MS_SSIM_L1_LOSS(nn.Module):
    """Weighted mix of a multi-scale SSIM term and a Gaussian-weighted L1 term.

    Every input channel is blurred with one Gaussian kernel per entry of
    ``gaussian_sigmas``. The SSIM contrast/structure term is multiplied over
    all channels and scales, the luminance term is taken at the last sigma
    only, and the L1 error map is blurred with the last sigma as well.

    The Gaussian kernels are kept in a non-persistent buffer, so
    ``module.to(device)`` moves them and they never appear in ``state_dict``.
    A CUDA device is much faster but is not required.

    Args:
        alpha (float): weight of the MS-SSIM term; the L1 term gets
            ``1 - alpha``.
        gaussian_sigmas (sequence of float): standard deviations of the
            Gaussian kernels. The last entry fixes the kernel size and the
            padding.
        data_range (float): value range of the inputs; scales the SSIM
            constants and normalises the L1 term.
        K (tuple of float): the two SSIM stabilisation constants.
        compensation (float): final multiplier applied to the mixed loss.
        cuda_dev (int): CUDA device index the kernels are initially placed on
            when CUDA is available.
        channels (int): number of channels of the input images.
    """

    def __init__(self, alpha,
                 gaussian_sigmas=(0.5, 1.0, 2.0, 4.0, 8.0),
                 data_range=1.0,
                 K=(0.01, 0.4),
                 compensation=1.0,
                 cuda_dev=0,
                 channels=3):
        super().__init__()
        gaussian_sigmas = tuple(gaussian_sigmas)
        self.DR = data_range
        self.C1 = (K[0] * data_range) ** 2
        self.C2 = (K[1] * data_range) ** 2
        self.pad = int(2 * gaussian_sigmas[-1])
        self.alpha = alpha
        self.compensation = compensation
        self.channels = channels
        self.num_scales = len(gaussian_sigmas)

        filter_size = int(4 * gaussian_sigmas[-1] + 1)
        kernels = torch.stack(
            [self._fspecial_gauss_2d(filter_size, sigma)
             for sigma in gaussian_sigmas]
        )  # [S, k, k]
        # F.conv2d with groups=channels hands input channel c the contiguous
        # weight rows [c*S, (c+1)*S). The masks are therefore laid out
        # channel-major (c0s0..c0sS-1, c1s0..c1sS-1, ...) so that every
        # channel is filtered with every sigma exactly once.
        g_masks = kernels.unsqueeze(1).repeat(channels, 1, 1, 1)  # [C*S, 1, k, k]

        # Keep the historical default of living on ``cuda_dev``, but stay
        # usable on CPU-only hosts instead of raising.
        if torch.cuda.is_available():
            g_masks = g_masks.cuda(cuda_dev)
        self.register_buffer("g_masks", g_masks, persistent=False)

        # Scalar logging is auxiliary: a missing tensorboard installation
        # must not prevent the loss itself from being computed.
        try:
            from torch.utils.tensorboard import SummaryWriter
        except ImportError:
            import warnings
            warnings.warn(
                "tensorboard is not installed; MS_SSIM_L1_LOSS will not log scalars",
                RuntimeWarning,
            )
            self.writer = None
        else:
            self.writer = SummaryWriter()

    def _fspecial_gauss_1d(self, size, sigma):
        """Create 1-D gauss kernel

        Args:
            size (int): the size of gauss kernel
            sigma (float): sigma of normal distribution

        Returns:
            torch.Tensor: 1D kernel (size), normalised to sum to one
        """
        coords = torch.arange(size).to(dtype=torch.float)
        coords -= size // 2
        g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
        g /= g.sum()
        return g.reshape(-1)

    def _fspecial_gauss_2d(self, size, sigma):
        """Create 2-D gauss kernel

        Args:
            size (int): the size of gauss kernel
            sigma (float): sigma of normal distribution

        Returns:
            torch.Tensor: 2D kernel (size x size)
        """
        gaussian_vec = self._fspecial_gauss_1d(size, sigma)
        return torch.outer(gaussian_vec, gaussian_vec)

    def forward(self, x, y, batch_idx):
        '''
        Args:
            x (tensor): first image batch, [B, channels, H, W]
            y (tensor): second image batch, same shape as x
            batch_idx (int): the iteration now (used as the logging step)

        Returns:
            combined_loss (torch): loss value of L1 with MS-SSIM loss
        '''
        # Follow the input's device and dtype; a no-op when they already match.
        masks = self.g_masks.to(device=x.device, dtype=x.dtype)
        groups = self.channels
        scales = self.num_scales

        def blur(t):
            return F.conv2d(t, masks, groups=groups, padding=self.pad)

        mux = blur(x)
        muy = blur(y)

        mux2 = mux * mux
        muy2 = muy * muy
        muxy = mux * muy

        sigmax2 = blur(x * x) - mux2
        sigmay2 = blur(y * y) - muy2
        sigmaxy = blur(x * y) - muxy

        # l(j), cs(j) in MS-SSIM, both [B, C*S, H, W] in channel-major order
        lum = (2 * muxy + self.C1) / (mux2 + muy2 + self.C1)
        cs = (2 * sigmaxy + self.C2) / (sigmax2 + sigmay2 + self.C2)

        # Luminance is used at the last sigma only: index S-1 within each
        # channel's slice, i.e. every S-th entry starting at S-1.
        lM = lum[:, scales - 1::scales, :, :].prod(dim=1)
        PIcs = cs.prod(dim=1)

        loss_ms_ssim = 1 - lM * PIcs  # [B, H, W]

        loss_l1 = F.l1_loss(x, y, reduction='none')  # [B, C, H, W]
        # Blur the L1 map with the last-sigma kernel of each channel, then
        # average over channels.
        gaussian_l1 = F.conv2d(loss_l1, masks[scales - 1::scales],
                               groups=groups, padding=self.pad).mean(1)  # [B, H, W]

        loss_mix = self.alpha * loss_ms_ssim + (1 - self.alpha) * gaussian_l1 / self.DR
        loss_mix = self.compensation * loss_mix

        combined_loss = loss_mix.mean()

        if self.writer is not None:
            self.writer.add_scalar('Loss/ms_ssim_loss-iteration', loss_ms_ssim.mean(), batch_idx)
            self.writer.add_scalar('Loss/l1_loss-iteration', gaussian_l1.mean(), batch_idx)

        return combined_loss
|