|
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
import torch.nn as nn |
|
import math |
|
|
|
# Luminance (Y) quantization table. The literal below is written row by row in
# the layout of the example luminance table of ITU-T T.81 Annex K (Table K.1)
# and is then transposed, so the stored tensor at [i, j] holds literal row j,
# column i. Wrapping it in nn.Parameter makes it a leaf tensor with
# requires_grad=True that shares memory with the transposed NumPy array.
y_table = nn.Parameter(
    torch.from_numpy(
        np.array(
            [
                [16, 11, 10, 16, 24, 40, 51, 61],
                [12, 12, 14, 19, 26, 58, 60, 55],
                [14, 13, 16, 24, 40, 57, 69, 56],
                [14, 17, 22, 29, 51, 87, 80, 62],
                [18, 22, 37, 56, 68, 109, 103, 77],
                [24, 35, 55, 64, 81, 104, 113, 92],
                [49, 64, 78, 87, 103, 121, 120, 101],
                [72, 92, 95, 98, 112, 100, 103, 99],
            ],
            dtype=np.float32,
        ).T
    )
)
|
|
|
# Chrominance (Cb/Cr) quantization table. Every entry is 99 except the
# top-left 4x4 corner, which is filled from the literal below (the same layout
# as ITU-T T.81 Annex K, Table K.2). The corner literal is symmetric, so the
# transpose leaves its values unchanged; it is kept to mirror the luminance
# table's convention. The result is wrapped in nn.Parameter
# (requires_grad=True).
c_table = np.full((8, 8), 99, dtype=np.float32)
c_table[:4, :4] = np.array(
    [
        [17, 18, 24, 47],
        [18, 21, 26, 66],
        [24, 26, 56, 99],
        [47, 66, 99, 99],
    ]
).T
c_table = nn.Parameter(torch.from_numpy(c_table))
|
|
|
|
|
def diff_round_back(x):
    """Round ``x`` with a cubic correction term so gradients are non-zero.

    Computes ``round(x) + (x - round(x)) ** 3``. Because ``torch.round`` has
    zero gradient, the derivative with respect to ``x`` comes entirely from
    the cubic residual term, namely ``3 * (x - round(x)) ** 2``.

    Input:
        x(tensor)
    Output:
        x(tensor): same shape as the input
    """
    nearest = torch.round(x)
    residual = x - nearest
    return nearest + residual ** 3
|
|
|
|
|
def diff_round(input_tensor, num_terms=9):
    """Differentiable approximation of rounding via a truncated Fourier series.

    Uses the sawtooth expansion
        x - round(x) = (1/pi) * sum_{n>=1} (-1)^(n+1) * sin(2*pi*n*x) / n
    truncated after ``num_terms`` terms, and returns
        x - (1/pi) * sum_{n=1}^{num_terms} (-1)^(n+1) * sin(2*pi*n*x) / n.
    Every term is smooth, so gradients flow through it (unlike torch.round).

    Input:
        input_tensor(tensor): values to round.
        num_terms(int): number of Fourier terms. The default of 9 reproduces
            the previously hard-coded behavior. More terms track round() more
            closely away from half-integers, at proportionally higher cost.
            0 returns the input unchanged.

    Output:
        tensor: approximately rounded values, same shape as input_tensor.

    Raises:
        ValueError: if num_terms is negative.
    """
    if num_terms < 0:
        raise ValueError(f"num_terms must be non-negative, got {num_terms}")
    # Partial sum of the sawtooth series. Python's sum() starts at int 0, so an
    # empty range (num_terms == 0) leaves the input unchanged below.
    sawtooth = sum(
        (-1) ** (n + 1) / n * torch.sin(2 * math.pi * n * input_tensor)
        for n in range(1, num_terms + 1)
    )
    return input_tensor - 1 / math.pi * sawtooth
|
|
|
|
|
class Quant(torch.autograd.Function):
    """Snap values to 8-bit levels with a straight-through gradient.

    Forward clamps the input into [0, 1], rounds it to the nearest multiple of
    1/255 (ties follow torch.round, i.e. round-half-to-even on the scaled
    value), and returns the result back in the [0, 1] scale.

    Backward returns the incoming gradient unchanged. This straight-through
    estimator also applies to elements that were clamped in the forward pass.
    """

    @staticmethod
    def forward(ctx, input):
        levels = torch.round(torch.clamp(input, 0, 1) * 255.0)
        return levels / 255.0

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient: treat the quantizer as y = x for backprop.
        return grad_output
|
|
|
|
|
class Quantization(nn.Module):
    """Module wrapper so the ``Quant`` autograd function can sit in a model.

    Holds no parameters or state; ``forward`` simply delegates to
    ``Quant.apply``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        quantized = Quant.apply(input)
        return quantized
|
|
|
|
|
def quality_to_factor(quality):
    """Convert a JPEG quality setting into a quantization-table scale factor.

    Follows the IJG libjpeg ``jpeg_quality_scaling`` rule:
    - qualities below 50 scale as ``5000 / quality``;
    - qualities of 50 and above scale as ``200 - 2 * quality``;
    - the percentage is then divided by 100 to give a multiplicative factor.

    As in libjpeg, the input is clamped first: values <= 0 are treated as 1
    and values above 100 as 100. This keeps quality 0 from dividing by zero
    and keeps qualities above 100 from producing negative factors.

    Input:
        quality(float): Quality for jpeg compression, nominally in [1, 100]
    Output:
        factor(float): Compression factor (1.0 at quality 50, 0.0 at quality 100)
    """
    if quality <= 0:
        quality = 1
    elif quality > 100:
        quality = 100

    if quality < 50:
        percent = 5000.0 / quality
    else:
        percent = 200.0 - quality * 2
    return percent / 100.0
|
|