# Vincentqyw
# fix: roma
# 8b973ee
# NOTE(review): the three lines above plus "raw / history blame / 6.28 kB"
# were web file-viewer residue, not valid Python; commented out to keep the
# module importable.
# Standard libraries
import itertools
import numpy as np
# PyTorch
import torch
import torch.nn as nn
# Local
from . import JPEG_utils as utils
class y_dequantize(nn.Module):
    """Invert JPEG quantization for the luma (Y) channel.

    Each 8x8 block of DCT coefficients is multiplied element-wise by the
    standard JPEG luminance quantization table scaled by ``factor``.

    Inputs:
        image(tensor): batch x height x width
        factor(float): compression factor
    Outputs:
        image(tensor): batch x height x width
    """

    def __init__(self, factor=1):
        super(y_dequantize, self).__init__()
        self.factor = factor
        self.y_table = utils.y_table

    def forward(self, image):
        scaled_table = self.y_table * self.factor
        return image * scaled_table
class c_dequantize(nn.Module):
    """Invert JPEG quantization for the chroma (Cb/Cr) channels.

    Each 8x8 block of DCT coefficients is multiplied element-wise by the
    standard JPEG chrominance quantization table scaled by ``factor``.

    Inputs:
        image(tensor): batch x height x width
        factor(float): compression factor
    Outputs:
        image(tensor): batch x height x width
    """

    def __init__(self, factor=1):
        super(c_dequantize, self).__init__()
        self.factor = factor
        self.c_table = utils.c_table

    def forward(self, image):
        scaled_table = self.c_table * self.factor
        return image * scaled_table
class idct_8x8(nn.Module):
    """Inverse 8x8 Discrete Cosine Transform (JPEG IDCT).

    Input:
        dcp(tensor): DCT coefficients, shape (..., 8, 8)
    Output:
        image(tensor): pixel values, same shape, shifted back by +128
    """

    def __init__(self):
        super(idct_8x8, self).__init__()
        # DC-term normalization: 1/sqrt(2) for index 0, 1 otherwise.
        alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7)
        self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
        # Precomputed cosine basis for the separable 2D IDCT.
        tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
        for x, y, u, v in itertools.product(range(8), repeat=4):
            tensor[x, y, u, v] = np.cos((2 * u + 1) * x * np.pi / 16) * np.cos(
                (2 * v + 1) * y * np.pi / 16
            )
        self.tensor = nn.Parameter(torch.from_numpy(tensor).float())

    def forward(self, image):
        image = image * self.alpha
        # Contract the trailing (8, 8) dims against the basis; +128 undoes the
        # level shift applied during compression.
        result = 0.25 * torch.tensordot(image, self.tensor, dims=2) + 128
        # BUG FIX: the original called `result.view(image.shape)` and discarded
        # the return value (torch.Tensor.view is NOT in-place). tensordot with
        # dims=2 already produces image's shape, so we restore it explicitly.
        result = result.view(image.shape)
        return result
class block_merging(nn.Module):
    """Reassemble 8x8 patches into a full image plane.

    Inputs:
        patches(tensor): batch x (height*width/64) x 8 x 8
        height(int): target plane height
        width(int): target plane width
    Output:
        image(tensor): batch x height x width
    """

    def __init__(self):
        super(block_merging, self).__init__()

    def forward(self, patches, height, width):
        block = 8
        batch = patches.shape[0]
        rows, cols = height // block, width // block
        # Lay patches out on a (rows, cols) grid, then interleave the
        # per-patch row axis with the grid row axis to form scanlines.
        grid = patches.view(batch, rows, cols, block, block)
        interleaved = grid.permute(0, 1, 3, 2, 4)
        return interleaved.contiguous().view(batch, height, width)
class chroma_upsampling(nn.Module):
    """Upsample subsampled chroma planes and stack with luma.

    Cb and Cr are expanded 2x in each spatial dimension by pixel
    duplication (nearest neighbor), then concatenated with Y.

    Input:
        y(tensor): y channel image
        cb(tensor): cb channel
        cr(tensor): cr channel
    Output:
        image(tensor): batch x height x width x 3
    """

    def __init__(self):
        super(chroma_upsampling, self).__init__()

    def forward(self, y, cb, cr):
        def upsample2x(channel, k=2):
            h, w = channel.shape[1:3]
            # Duplicate each pixel into a k x k tile.
            expanded = channel.unsqueeze(-1).repeat(1, 1, k, k)
            return expanded.view(-1, h * k, w * k)

        planes = [y.unsqueeze(3), upsample2x(cb).unsqueeze(3), upsample2x(cr).unsqueeze(3)]
        return torch.cat(planes, dim=3)
class ycbcr_to_rgb_jpeg(nn.Module):
    """Convert a YCbCr image to RGB using the JPEG (BT.601) coefficients.

    Input:
        image(tensor): batch x height x width x 3 (Y, Cb, Cr channels last)
    Output:
        result(tensor): batch x 3 x height x width (channels first)
    """

    def __init__(self):
        super(ycbcr_to_rgb_jpeg, self).__init__()
        # Standard JPEG YCbCr->RGB matrix, transposed so channels-last
        # pixels can be right-multiplied via tensordot.
        matrix = np.array(
            [[1.0, 0.0, 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]],
            dtype=np.float32,
        ).T
        # Cb/Cr are stored with a +128 offset; undo it before the matmul.
        self.shift = nn.Parameter(torch.tensor([0, -128.0, -128.0]))
        self.matrix = nn.Parameter(torch.from_numpy(matrix))

    def forward(self, image):
        result = torch.tensordot(image + self.shift, self.matrix, dims=1)
        # BUG FIX: the original called `result.view(image.shape)` and discarded
        # the return value (torch.Tensor.view is NOT in-place). tensordot with
        # dims=1 already yields image's shape, so the dead call is removed.
        return result.permute(0, 3, 1, 2)
class decompress_jpeg(nn.Module):
    """Full JPEG decompression pipeline.

    Dequantizes each channel's DCT blocks, applies the inverse DCT, merges
    the 8x8 patches into planes, upsamples chroma, converts YCbCr to RGB,
    and clamps to [0, 255] before rescaling to [0, 1].

    Input:
        y, cb, cr(tensor): quantized DCT blocks, batch x h*w/64 x 8 x 8
        height, width(int): luma plane dimensions
        rounding(function): rounding function (kept for API compatibility)
        factor(float): compression factor
    Output:
        image(tensor): batch x 3 x height x width
    """

    def __init__(self, rounding=torch.round, factor=1):
        super(decompress_jpeg, self).__init__()
        self.c_dequantize = c_dequantize(factor=factor)
        self.y_dequantize = y_dequantize(factor=factor)
        self.idct = idct_8x8()
        self.merging = block_merging()
        # comment this line if no subsampling
        self.chroma = chroma_upsampling()
        self.colors = ycbcr_to_rgb_jpeg()

    def forward(self, y, cb, cr, height, width):
        self.height = height
        self.width = width
        planes = {}
        for name, blocks in (("y", y), ("cb", cb), ("cr", cr)):
            if name == "y":
                coeffs = self.y_dequantize(blocks)
                plane_h, plane_w = self.height, self.width
            else:
                coeffs = self.c_dequantize(blocks)
                # chroma planes are subsampled: half size in each dimension
                # (use full size instead if no subsampling)
                plane_h, plane_w = int(self.height / 2), int(self.width / 2)
            pixels = self.idct(coeffs)
            planes[name] = self.merging(pixels, plane_h, plane_w)
        # comment this line if no subsampling
        image = self.chroma(planes["y"], planes["cb"], planes["cr"])
        image = self.colors(image)
        # Clamp to the valid pixel range, then normalize to [0, 1].
        image = torch.min(
            255 * torch.ones_like(image), torch.max(torch.zeros_like(image), image)
        )
        return image / 255