import torch
import torch.nn as nn

from .JPEG_utils import diff_round, quality_to_factor, Quantization
from .compression import compress_jpeg
from .decompression import decompress_jpeg


class DiffJPEG(nn.Module):
    """JPEG compress/decompress round trip packaged as an ``nn.Module``.

    The module builds one compression stage and one decompression stage,
    both configured with the same rounding function and the same factor
    derived from ``quality``.
    """

    def __init__(self, differentiable=True, quality=75):
        """Build the compression and decompression stages.

        Inputs:
            differentiable(bool): If true the stages use the custom
                ``diff_round`` rounding function, if false they use the
                standard ``torch.round``.
            quality(float): Quality setting for the jpeg compression
                scheme; converted with ``quality_to_factor`` before being
                handed to both stages.
        """
        super().__init__()
        # Same rounding callable is shared by both stages.
        rounding = diff_round if differentiable else torch.round
        factor = quality_to_factor(quality)
        self.compress = compress_jpeg(rounding=rounding, factor=factor)
        # The output size is not fixed here; it is supplied per call in
        # ``forward``.
        self.decompress = decompress_jpeg(rounding=rounding, factor=factor)

    def forward(self, x):
        """Compress ``x`` and immediately decompress it again.

        Inputs:
            x: Input whose ``shape[2]`` and ``shape[3]`` are read as the
                original height and width (presumably a
                (batch, channel, height, width) image tensor -- confirm
                against ``compress_jpeg``).
        Returns:
            Whatever ``self.decompress`` produces for the three compressed
            components together with the original height and width.
        """
        # Record the input size first so the decompressor can be told what
        # size to restore.
        height, width = x.shape[2], x.shape[3]
        y_comp, cb_comp, cr_comp = self.compress(x)
        return self.decompress(y_comp, cb_comp, cr_comp, height, width)