# Provenance (web file-viewer header, not code): Vincentqyw, "fix: roma", commit 358ab8f, raw / history / blame, 1.4 kB
import torch
import torch.nn as nn
from .JPEG_utils import diff_round, quality_to_factor, Quantization
from .compression import compress_jpeg
from .decompression import decompress_jpeg
class DiffJPEG(nn.Module):
    """JPEG compress-then-decompress layer with optional differentiable rounding.

    The layer chains a ``compress_jpeg`` stage and a ``decompress_jpeg`` stage
    that share the same rounding function and quality-derived factor.
    """

    def __init__(self, differentiable: bool = True, quality: float = 75):
        """Initialize the DiffJPEG layer.

        The image height and width are not constructor arguments; they are
        read from the input on every ``forward`` call.

        Args:
            differentiable: If true, use the custom differentiable rounding
                function ``diff_round``; if false, use the standard
                ``torch.round``.
            quality: Quality factor for the JPEG compression scheme. It is
                converted with ``quality_to_factor`` and the resulting factor
                is handed to both the compression and decompression stages.
        """
        super().__init__()
        rounding = diff_round if differentiable else torch.round
        factor = quality_to_factor(quality)
        self.compress = compress_jpeg(rounding=rounding, factor=factor)
        self.decompress = decompress_jpeg(rounding=rounding, factor=factor)

    def forward(self, x: torch.Tensor):
        """Run ``x`` through compression and then decompression.

        Args:
            x: Input whose ``shape[2]`` is taken as the original height and
                ``shape[3]`` as the original width (presumably a
                ``(batch, channels, height, width)`` image tensor -- confirm
                against ``compress_jpeg``).

        Returns:
            Whatever ``self.decompress`` returns for the three components
            produced by ``self.compress`` and the original height and width.

        Raises:
            ValueError: If ``x`` has fewer than 4 dimensions, so that height
                and width cannot be read from it.
        """
        # Fail early with a clear message instead of an opaque
        # "tuple index out of range" from the shape lookups below.
        if len(x.shape) < 4:
            raise ValueError(
                "DiffJPEG.forward expects an input with at least 4 dimensions "
                "(height at dim 2, width at dim 3); got shape "
                f"{tuple(x.shape)}"
            )
        # Remember the original size so decompression can restore it.
        org_height, org_width = x.shape[2], x.shape[3]
        y, cb, cr = self.compress(x)
        return self.decompress(y, cb, cr, org_height, org_width)