|
import torch |
|
import torch.nn as nn |
|
|
|
from .JPEG_utils import diff_round, quality_to_factor, Quantization |
|
from .compression import compress_jpeg |
|
from .decompression import decompress_jpeg |
|
|
|
|
|
class DiffJPEG(nn.Module):
    """JPEG compression followed by decompression, as a single nn.Module.

    ``forward`` runs the input through ``compress_jpeg`` (producing Y, Cb and
    Cr components) and then through ``decompress_jpeg``. It returns the
    reconstructed result.
    """

    def __init__(self, differentiable=True, quality=75):
        """Initialize the DiffJPEG layer.

        Args:
            differentiable (bool): If True, use the custom differentiable
                rounding function ``diff_round``. If False, use the standard
                ``torch.round``. The same rounding function is given to both
                the compression and decompression stages.
            quality (float): Quality setting for the JPEG compression scheme.
                It is converted to a scaling factor with ``quality_to_factor``,
                and that factor is shared by both stages.
        """
        super().__init__()
        rounding = diff_round if differentiable else torch.round
        factor = quality_to_factor(quality)
        # Attribute names are kept as-is: they are the submodule names and
        # therefore appear in state_dict keys.
        self.compress = compress_jpeg(rounding=rounding, factor=factor)
        self.decompress = decompress_jpeg(rounding=rounding, factor=factor)

    def forward(self, x):
        """Compress ``x`` and return its reconstruction.

        Args:
            x (torch.Tensor): Input image batch. Dimensions 2 and 3 are read
                as the height and width. Presumably the layout is
                (batch, channels, height, width) with RGB channels; this is a
                TODO to confirm against ``compress_jpeg``.

        Returns:
            torch.Tensor: The output of ``decompress_jpeg``. It receives the
            original height and width, presumably so it can undo any padding
            added during compression (TODO confirm).
        """
        org_height, org_width = x.shape[2], x.shape[3]
        y, cb, cr = self.compress(x)
        return self.decompress(y, cb, cr, org_height, org_width)
|
|