Spaces:
Sleeping
Sleeping
from dataclasses import dataclass | |
from pathlib import Path | |
from typing import Any | |
import torch | |
from PIL import Image | |
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import ( | |
MultiUpscaler, | |
UpscalerCheckpoints, | |
) | |
from esrgan_model import UpscalerESRGAN | |
# kw_only=True lets this required field be declared after any defaulted
# fields inherited from UpscalerCheckpoints (assumes the parent is itself a
# non-frozen dataclass -- TODO confirm against the installed refiners version).
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Upscaler checkpoints extended with the location of the ESRGAN weights."""

    # Path handed to UpscalerESRGAN when ESRGANUpscaler is constructed.
    esrgan: Path
class ESRGANUpscaler(MultiUpscaler):
    """MultiUpscaler variant that runs an ESRGAN pass before the parent's pre-upscale."""

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Initialise the parent upscaler, then build the ESRGAN model.

        Args:
            checkpoints: Checkpoint locations; ``checkpoints.esrgan`` is passed
                to ``UpscalerESRGAN``.
            device: Device forwarded to the parent and used to place ESRGAN.
            dtype: Dtype forwarded to the parent and used to place ESRGAN.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        esrgan_weights = checkpoints.esrgan
        self.esrgan = UpscalerESRGAN(
            esrgan_weights,
            device=self.device,
            dtype=self.dtype,
        )
        # Placement is requested again explicitly, in addition to the
        # device/dtype already given to the constructor above.
        self.esrgan.to(device=device, dtype=dtype)

    def to(self, device: torch.device, dtype: torch.dtype) -> None:
        """Move the ESRGAN model and ``self.sd`` to ``device``/``dtype``.

        The new placement is then recorded on ``self.device`` and
        ``self.dtype``. Nothing is returned.
        """
        self.esrgan.to(device=device, dtype=dtype)
        # Rebind ``sd`` to whatever its own ``to`` returns.
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device, self.dtype = device, dtype

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Run the tiled ESRGAN upscale, then delegate to the parent's pre-upscale.

        Extra keyword arguments are accepted and ignored; they are not
        forwarded to the parent implementation.
        """
        enlarged = self.esrgan.upscale_with_tiling(image)
        # NOTE(review): the division by 4 presumably compensates for the scale
        # already applied by the ESRGAN model -- confirm against esrgan_model.
        remaining_factor = upscale_factor / 4
        return super().pre_upscale(image=enlarged, upscale_factor=remaining_factor)