# Latest commit: 8b973ee by Vincentqyw ("fix: roma").
# NumPy
import numpy as np
# PyTorch
import torch
import torch.nn as nn
import math
# Baseline JPEG quantization tables. The values match the example tables of
# ITU-T T.81 (JPEG) Annex K: Table K.1 (luminance) and Table K.2 (chrominance).
# Both are stored transposed relative to the row-major listings written below.
# NOTE(review): presumably the transpose matches the (u, v) layout of the DCT
# code that consumes these tables; confirm against the caller.
#
# Both end up as float32 nn.Parameter tensors of shape (8, 8), so they carry
# requires_grad=True.

# Luminance (Y) table.
y_table = nn.Parameter(
    torch.from_numpy(
        np.transpose(
            np.array(
                [
                    [16, 11, 10, 16, 24, 40, 51, 61],
                    [12, 12, 14, 19, 26, 58, 60, 55],
                    [14, 13, 16, 24, 40, 57, 69, 56],
                    [14, 17, 22, 29, 51, 87, 80, 62],
                    [18, 22, 37, 56, 68, 109, 103, 77],
                    [24, 35, 55, 64, 81, 104, 113, 92],
                    [49, 64, 78, 87, 103, 121, 120, 101],
                    [72, 92, 95, 98, 112, 100, 103, 99],
                ],
                dtype=np.float32,
            )
        )
    )
)

# Chrominance (C) table: 99 everywhere except the low-frequency 4x4 corner.
# The 4x4 literal is symmetric, so its transpose has no effect on the values;
# it is kept for consistency with the luminance table.
c_table = np.full((8, 8), 99, dtype=np.float32)
c_table[:4, :4] = np.transpose(
    np.array(
        [
            [17, 18, 24, 47],
            [18, 21, 26, 66],
            [24, 26, 56, 99],
            [47, 66, 99, 99],
        ]
    )
)
c_table = nn.Parameter(torch.from_numpy(c_table))
def diff_round_back(x):
    """Round ``x`` while keeping a usable gradient.

    Computes ``round(x) + (x - round(x)) ** 3``. The first term supplies the
    rounded value. The cubic residual is at most 0.125 in magnitude (because
    ``|x - round(x)| <= 0.5``), and its derivative ``3 * (x - round(x)) ** 2``
    lets gradients flow where ``torch.round`` alone would contribute none.

    Input:
        x(tensor)
    Output:
        tensor, element-wise result with the same shape as ``x``
    """
    rounded = torch.round(x)
    residual = x - rounded
    return rounded + residual ** 3
def diff_round(input_tensor, num_terms=9):
    """Differentiable approximation of rounding via a truncated Fourier series.

    The period-1 sawtooth satisfies
    ``x - round(x) = (1/pi) * sum_{n>=1} (-1)**(n+1) * sin(2*pi*n*x) / n``,
    so rounding is approximated by truncating that series:

        round(x) ~= x - (1/pi) * sum_{n=1}^{num_terms} (-1)**(n+1) / n * sin(2*pi*n*x)

    The result is smooth in ``x``, so gradients flow through it. Accuracy
    improves with ``num_terms`` away from half-integers. Near half-integers
    (the sawtooth's jumps) the truncated series overshoots (Gibbs phenomenon).

    Input:
        input_tensor(tensor): values to round.
        num_terms(int): number of Fourier harmonics to sum. The default of 9
            preserves the original behavior. ``0`` returns a tensor equal to
            ``input_tensor``.
    Output:
        tensor, element-wise result with the same shape as ``input_tensor``
    Raises:
        ValueError: if ``num_terms`` is negative.
    """
    if num_terms < 0:
        raise ValueError(f"num_terms must be non-negative, got {num_terms}")
    series = 0
    for n in range(1, num_terms + 1):
        # (-1)**(n+1) / n: +1/n for odd harmonics, -1/n for even ones.
        coeff = (1.0 if n % 2 == 1 else -1.0) / n
        series += coeff * torch.sin(2 * math.pi * n * input_tensor)
    return input_tensor - 1 / math.pi * series
class Quant(torch.autograd.Function):
    """Snap values to the 8-bit grid in [0, 1] with a straight-through gradient.

    Forward: clamp to [0, 1], then round to the nearest multiple of 1/255
    (``torch.round`` rounds exact halves to the nearest even integer).
    Backward: return the incoming gradient unchanged (straight-through
    estimator), including at positions that were clamped in the forward pass.
    """

    @staticmethod
    def forward(ctx, input):
        clamped = torch.clamp(input, 0, 1)
        # Scale to the 0..255 integer range, round, then scale back to [0, 1].
        return torch.round(clamped * 255.0) / 255.0

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient: treat the quantizer as if it were y = x.
        return grad_output
class Quantization(nn.Module):
    """``nn.Module`` wrapper that applies the ``Quant`` autograd function.

    It registers no parameters or buffers of its own. ``forward`` simply
    delegates to ``Quant.apply``.
    """

    def forward(self, input):
        return Quant.apply(input)
def quality_to_factor(quality):
    """Convert a JPEG quality setting to a quantization-table scale factor.

    Uses the IJG libjpeg quality-scaling formula: ``5000 / quality`` for
    ``quality < 50`` and ``200 - 2 * quality`` otherwise, divided by 100.
    Quality 50 maps to 1.0 (tables unchanged). Lower quality gives a larger
    factor (coarser quantization) and higher quality a smaller one.

    Input:
        quality(float): Quality for jpeg compression; must satisfy
            ``0 < quality <= 100``.
    Output:
        factor(float): Compression factor.
            NOTE(review): ``quality == 100`` yields ``0.0``. Any caller that
            divides by ``table * factor`` must guard against that.
    Raises:
        ValueError: if ``quality`` is outside ``(0, 100]`` or is NaN. Outside
            that range the formula divides by zero or produces negative factors.
    """
    # `not (a < q <= b)` is also True for NaN, which would otherwise slip through.
    if not 0 < quality <= 100:
        raise ValueError(f"quality must be in (0, 100], got {quality!r}")
    if quality < 50:
        scale = 5000.0 / quality
    else:
        scale = 200.0 - quality * 2
    return scale / 100.0