Spaces:
Running
Running
File size: 3,418 Bytes
a660631 f521e88 a660631 7a1ec93 a660631 f521e88 a660631 7a1ec93 f521e88 a660631 7a1ec93 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 f521e88 a660631 7a1ec93 a660631 7a1ec93 a660631 7a1ec93 a660631 f521e88 a660631 f521e88 a660631 7a1ec93 a660631 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import gc
import numpy as np
import PIL.Image
import torch
from controlnet_aux import (
CannyDetector,
ContentShuffleDetector,
HEDdetector,
LineartAnimeDetector,
LineartDetector,
MidasDetector,
MLSDdetector,
NormalBaeDetector,
OpenposeDetector,
PidiNetDetector,
)
from controlnet_aux.util import HWC3
from cv_utils import resize_image
from depth_estimator import DepthEstimator
from image_segmentor import ImageSegmentor, ImageSegmentorOneFormer
class Preprocessor:
    """Load ControlNet annotator models on demand and apply them to images.

    ``load(name)`` selects the active model; every model that has been built
    is kept in ``self.models`` so switching back to it does not rebuild it.
    Calling the instance runs the active model on an image.
    """

    # Hub repository passed to ``from_pretrained`` for the annotator models.
    MODEL_ID = "lllyasviel/Annotators"

    def __init__(self):
        self.model = None  # currently active model (None until load() succeeds)
        self.models = {}  # name -> already constructed model
        self.name = ""  # name of the currently active model

    def _build_model(self, name: str):
        """Construct and return a new model for ``name``.

        Raises:
            ValueError: if ``name`` is not a known preprocessor.
        """
        model_id = self.MODEL_ID
        # Zero-argument factories keep construction lazy: only the requested
        # model is instantiated.
        factories = {
            "HED": lambda: HEDdetector.from_pretrained(model_id),
            "Midas": lambda: MidasDetector.from_pretrained(model_id),
            "MLSD": lambda: MLSDdetector.from_pretrained(model_id),
            "Openpose": lambda: OpenposeDetector.from_pretrained(model_id),
            "PidiNet": lambda: PidiNetDetector.from_pretrained(model_id),
            "NormalBae": lambda: NormalBaeDetector.from_pretrained(model_id),
            "Lineart": lambda: LineartDetector.from_pretrained(model_id),
            "LineartAnime": lambda: LineartAnimeDetector.from_pretrained(model_id),
            "Canny": lambda: CannyDetector(),
            "ContentShuffle": lambda: ContentShuffleDetector(),
            "DPT": lambda: DepthEstimator(),
            "UPerNet": lambda: ImageSegmentor(),
            "OneFormer": lambda: ImageSegmentorOneFormer(),
        }
        factory = factories.get(name)
        if factory is None:
            raise ValueError(f"Unknown preprocessor name: {name!r}")
        return factory()

    def load(self, name: str) -> None:
        """Make the model called ``name`` the active one.

        Does nothing when ``name`` is already active. A model that was loaded
        earlier is taken from ``self.models``; otherwise it is constructed and
        stored there.

        Raises:
            ValueError: if ``name`` is not a known preprocessor. The active
                model and the cache are left unchanged in that case.
        """
        if name == self.name:
            return

        if name in self.models:
            model = self.models[name]
        else:
            model = self._build_model(name)
            # NOTE: no torch.cuda.empty_cache() / gc.collect() here. The
            # previously active model stays referenced from self.models, so
            # switching models does not release it.
            self.models[name] = model

        self.model = model
        self.name = name

    def __call__(self, image: PIL.Image.Image, **kwargs) -> PIL.Image.Image:
        """Run the active model on ``image``.

        Args:
            image: input image.
            **kwargs: forwarded to the active model. For "Canny" the optional
                ``detect_resolution`` is consumed here; for "Midas" both
                ``detect_resolution`` and ``image_resolution`` are consumed
                here (each defaulting to 512).

        Returns:
            For "Canny" and "Midas", a PIL image built from the model's array
            output; for every other model, whatever the model returns.
        """
        if self.name == "Canny":
            detect_resolution = kwargs.pop("detect_resolution", None)
            # Always hand the detector an array, not only when a resize was
            # requested; the result is wrapped with fromarray() below.
            image = HWC3(np.array(image))
            if detect_resolution is not None:
                image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            return PIL.Image.fromarray(image)

        if self.name == "Midas":
            detect_resolution = kwargs.pop("detect_resolution", 512)
            image_resolution = kwargs.pop("image_resolution", 512)
            image = HWC3(np.array(image))
            image = resize_image(image, resolution=detect_resolution)
            image = self.model(image, **kwargs)
            if isinstance(image, tuple):
                # "normal old": use the last element of the tuple with its
                # last axis reversed (presumably a normal map with flipped
                # channel order -- TODO confirm against MidasDetector).
                image = image[-1][..., ::-1]
            image = HWC3(image)
            image = resize_image(image, resolution=image_resolution)
            return PIL.Image.fromarray(image)

        # All other models take the image as given and return their own output.
        return self.model(image, **kwargs)
|