File size: 1,987 Bytes
10b4a5f
 
358ab8f
10b4a5f
 
 
 
 
 
358ab8f
 
 
 
 
 
 
 
 
 
 
 
10b4a5f
 
 
 
 
358ab8f
 
 
10b4a5f
 
 
 
358ab8f
10b4a5f
 
 
 
 
358ab8f
10b4a5f
 
 
 
 
358ab8f
10b4a5f
 
 
 
 
 
 
 
358ab8f
10b4a5f
 
 
 
 
 
358ab8f
10b4a5f
 
 
 
 
 
 
 
 
358ab8f
10b4a5f
 
 
 
 
 
358ab8f
10b4a5f
358ab8f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# Standard libraries
import numpy as np

# PyTorch
import torch
import torch.nn as nn
import math

# Luminance (Y) quantization table. The values below are the standard JPEG
# luminance table (ITU-T T.81, Annex K.1), listed row by row and then
# transposed, so ``y_table[i, j]`` holds the listed value at row ``j``,
# column ``i``. It is wrapped in ``nn.Parameter`` (requires_grad=True by
# default) and lives at module level, so every importer shares this object.
y_table = nn.Parameter(
    torch.from_numpy(
        np.array(
            [
                [16, 11, 10, 16, 24, 40, 51, 61],
                [12, 12, 14, 19, 26, 58, 60, 55],
                [14, 13, 16, 24, 40, 57, 69, 56],
                [14, 17, 22, 29, 51, 87, 80, 62],
                [18, 22, 37, 56, 68, 109, 103, 77],
                [24, 35, 55, 64, 81, 104, 113, 92],
                [49, 64, 78, 87, 103, 121, 120, 101],
                [72, 92, 95, 98, 112, 100, 103, 99],
            ],
            dtype=np.float32,
        ).transpose()
    )
)
# Chrominance (Cb/Cr) quantization table. The standard JPEG chrominance
# values (ITU-T T.81, Annex K.2) fill the top-left 4x4 corner. Every entry of
# the standard table outside that corner is 99, so the rest is filled with 99.
c_table = np.full((8, 8), 99, dtype=np.float32)
# The 4x4 corner is symmetric, so this transpose leaves every value unchanged.
c_table[:4, :4] = np.array(
    [
        [17, 18, 24, 47],
        [18, 21, 26, 66],
        [24, 26, 56, 99],
        [47, 66, 99, 99],
    ]
).transpose()
c_table = nn.Parameter(torch.from_numpy(c_table))


def diff_round_back(x):
    """Round with a cubic correction term so the gradient is not zero.

    Returns ``round(x) + (x - round(x)) ** 3``. The result differs from
    ``torch.round(x)`` by at most ``0.5 ** 3 = 0.125``. ``torch.round`` has a
    zero gradient, so the derivative comes entirely from the cubic residual:
    ``3 * (x - round(x)) ** 2``.

    Input:
        x(tensor): values to round.
    Output:
        (tensor): rounded values plus the cubed rounding residual.
    """
    nearest = torch.round(x)
    residual = x - nearest
    return nearest + residual ** 3


def diff_round(input_tensor, num_terms=9):
    """Smooth, differentiable approximation of rounding via a Fourier series.

    Away from half-integers, the sawtooth ``x - round(x)`` has the series
        x - round(x) = (1/pi) * sum_{n>=1} (-1)**(n+1) * sin(2*pi*n*x) / n
    This function truncates that series after ``num_terms`` terms and returns
    ``x - series``. The result is a smooth function of ``input_tensor`` whose
    gradient is generally non-zero, unlike ``torch.round``.

    Input:
        input_tensor(tensor): values to round.
        num_terms(int): number of Fourier terms. Defaults to 9, the value
            that was previously hard-coded. More terms follow ``round`` more
            closely away from half-integers. Near those jumps, any truncated
            series overshoots (Gibbs phenomenon). ``0`` returns the input
            unchanged.
    Output:
        (tensor): smooth approximation of ``round(input_tensor)``.
    Raises:
        ValueError: if ``num_terms`` is negative.
    """
    if num_terms < 0:
        raise ValueError(f"num_terms must be non-negative, got {num_terms}")
    series = 0
    for n in range(1, num_terms + 1):
        # (-1) ** (n + 1): +1 for odd n, -1 for even n.
        sign = 1.0 if n % 2 else -1.0
        series += sign / n * torch.sin(2 * math.pi * n * input_tensor)
    return input_tensor - 1 / math.pi * series


class Quant(torch.autograd.Function):
    """Snap values to 256 evenly spaced levels with a pass-through gradient.

    Forward clamps the input to ``[0, 1]`` and rounds it to the nearest
    ``k / 255`` (``k = 0..255``). Backward returns the incoming gradient
    unchanged (a straight-through estimator), including for inputs the
    forward pass clamped.
    """

    @staticmethod
    def forward(ctx, input):
        # Scale to 0..255 so that rounding lands on integer levels.
        levels = torch.clamp(input, 0, 1) * 255.0
        return levels.round() / 255.0

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient: treat clamp + round as if they were not there.
        return grad_output


class Quantization(nn.Module):
    """``nn.Module`` wrapper that applies the ``Quant`` autograd function.

    Holds no parameters or buffers; ``forward`` passes its input straight to
    ``Quant.apply`` and returns the result.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        quantized = Quant.apply(input)
        return quantized


def quality_to_factor(quality):
    """Calculate factor corresponding to quality

    Uses the piecewise quality-scaling formula from libjpeg's
    ``jpeg_quality_scaling``: ``5000 / quality`` below 50 and
    ``200 - 2 * quality`` from 50 upward. The result is returned as a fraction
    (divided by 100) rather than a percentage.

    Input:
        quality(float): Quality for jpeg compression, in (0, 100].
    Output:
        factor(float): Compression factor (1.0 at quality 50, 0.0 at 100).
    Raises:
        ValueError: if quality is outside (0, 100]. Otherwise zero would
            divide by zero, and negative values or values above 100 would
            silently yield a negative factor.
    """
    # The chained comparison also rejects NaN.
    if not 0 < quality <= 100:
        raise ValueError(f"quality must be in (0, 100], got {quality!r}")
    if quality < 50:
        scale = 5000.0 / quality
    else:
        scale = 200.0 - quality * 2
    return scale / 100.0