File size: 1,404 Bytes
404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af 8b973ee 404d2af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 |
import torch
import torch.nn as nn
from .JPEG_utils import diff_round, quality_to_factor, Quantization
from .compression import compress_jpeg
from .decompression import decompress_jpeg
class DiffJPEG(nn.Module):
    """Run JPEG compression followed by decompression as a torch module.

    The module passes its input through ``compress_jpeg`` and then
    ``decompress_jpeg`` and returns the reconstructed image. With
    ``differentiable=True`` the rounding step uses the custom differentiable
    ``diff_round`` instead of ``torch.round``.
    """

    def __init__(self, differentiable=True, quality=75):
        """Initialize the DiffJPEG layer.

        Args:
            differentiable (bool): If True, use the custom differentiable
                rounding function ``diff_round``; if False, use the standard
                ``torch.round``.
            quality (float): Quality factor for the JPEG compression scheme.
                It is converted to a compression factor by
                ``quality_to_factor`` and shared by both the compression and
                decompression stages.

        Note:
            Image height and width are not constructor arguments; they are
            read from each input tensor in ``forward``.
        """
        super().__init__()
        rounding = diff_round if differentiable else torch.round
        factor = quality_to_factor(quality)
        self.compress = compress_jpeg(rounding=rounding, factor=factor)
        self.decompress = decompress_jpeg(rounding=rounding, factor=factor)

    def forward(self, x):
        """Compress ``x`` with JPEG and return the decompressed result.

        Args:
            x (torch.Tensor): Input images. Dimensions 2 and 3 are read as
                height and width (presumably ``(batch, channels, H, W)`` RGB
                input -- TODO confirm against ``compress_jpeg``).

        Returns:
            torch.Tensor: The image reconstructed by ``decompress_jpeg`` from
            the ``y``, ``cb`` and ``cr`` components produced by
            ``compress_jpeg``.
        """
        org_height, org_width = x.shape[2], x.shape[3]
        y, cb, cr = self.compress(x)
        # The original spatial size is handed to decompress, presumably so it
        # can undo any padding applied during compression -- verify in
        # decompression.py.
        return self.decompress(y, cb, cr, org_height, org_width)
|