Spaces:
Sleeping
Sleeping
File size: 1,266 Bytes
353f374 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import torch
from PIL import Image
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_upscaler import (
MultiUpscaler,
UpscalerCheckpoints,
)
from esrgan_model import UpscalerESRGAN
@dataclass(kw_only=True)
class ESRGANUpscalerCheckpoints(UpscalerCheckpoints):
    """Checkpoint locations needed to build an ``ESRGANUpscaler``.

    Extends the base ``UpscalerCheckpoints`` with one additional
    keyword-only field for the ESRGAN weights.
    """

    # Path handed to ``UpscalerESRGAN`` when the upscaler is constructed.
    esrgan: Path
class ESRGANUpscaler(MultiUpscaler):
    """Multi-upscaler that runs an ESRGAN pass before the parent's pre-upscale.

    The ESRGAN model is loaded from ``checkpoints.esrgan`` and applied to the
    input image with tiling; the result is then handed to
    ``MultiUpscaler.pre_upscale`` with a correspondingly reduced factor.
    """

    # Divisor applied to the requested upscale factor after the ESRGAN pass.
    # NOTE(review): presumably the fixed magnification of the ESRGAN
    # checkpoint (4x) -- confirm against the model actually loaded.
    _ESRGAN_SCALE = 4

    def __init__(
        self,
        checkpoints: ESRGANUpscalerCheckpoints,
        device: torch.device,
        dtype: torch.dtype,
    ) -> None:
        """Build the base upscaler, then load the ESRGAN model.

        Args:
            checkpoints: Checkpoint paths, including the ESRGAN one.
            device: Device passed to the base class and to the ESRGAN model.
            dtype: Dtype passed to the base class and to the ESRGAN model.
        """
        super().__init__(checkpoints=checkpoints, device=device, dtype=dtype)
        self.esrgan = UpscalerESRGAN(checkpoints.esrgan, device=self.device, dtype=self.dtype)
        # Explicit move kept from the original code; presumably redundant with
        # the constructor arguments above -- TODO confirm before removing.
        self.esrgan.to(device=device, dtype=dtype)

    def to(self, device: torch.device, dtype: torch.dtype) -> "ESRGANUpscaler":
        """Move the ESRGAN and SD models to ``device``/``dtype``.

        Also records the new device and dtype on the instance.

        Args:
            device: Target device.
            dtype: Target dtype.

        Returns:
            This instance, so the call can be chained or re-assigned
            (``upscaler = upscaler.to(...)``) as with ``torch.nn.Module.to``.
        """
        self.esrgan.to(device=device, dtype=dtype)
        self.sd = self.sd.to(device=device, dtype=dtype)
        self.device = device
        self.dtype = dtype
        return self

    def pre_upscale(self, image: Image.Image, upscale_factor: float, **_: Any) -> Image.Image:
        """Upscale ``image`` with ESRGAN, then delegate to the parent step.

        Args:
            image: Input image.
            upscale_factor: Overall factor requested by the caller.
            **_: Extra keyword arguments; accepted and ignored (they are not
                forwarded to the parent implementation).

        Returns:
            The image returned by the parent ``pre_upscale``.
        """
        upscaled = self.esrgan.upscale_with_tiling(image)
        # The ESRGAN pass has already enlarged the image, so the parent is
        # only asked for the remaining share of the requested factor.
        remaining_factor = upscale_factor / self._ESRGAN_SCALE
        return super().pre_upscale(image=upscaled, upscale_factor=remaining_factor)
|