Spaces: Runtime error

Update model.py

model.py CHANGED
@@ -18,6 +18,64 @@ from diffusers import (
     T2IAdapter,
 )
 
+SD_XL_BASE_RATIOS = {
+    "0.5": (704, 1408),
+    "0.52": (704, 1344),
+    "0.57": (768, 1344),
+    "0.6": (768, 1280),
+    "0.68": (832, 1216),
+    "0.72": (832, 1152),
+    "0.78": (896, 1152),
+    "0.82": (896, 1088),
+    "0.88": (960, 1088),
+    "0.94": (960, 1024),
+    "1.0": (1024, 1024),
+    "1.07": (1024, 960),
+    "1.13": (1088, 960),
+    "1.21": (1088, 896),
+    "1.29": (1152, 896),
+    "1.38": (1152, 832),
+    "1.46": (1216, 832),
+    "1.67": (1280, 768),
+    "1.75": (1344, 768),
+    "1.91": (1344, 704),
+    "2.0": (1408, 704),
+    "2.09": (1472, 704),
+    "2.4": (1536, 640),
+    "2.5": (1600, 640),
+    "2.89": (1664, 576),
+    "3.0": (1728, 576),
+}
+
+def find_closest_aspect_ratio(target_width, target_height):
+    target_ratio = target_width / target_height
+    closest_ratio = None
+    min_difference = float('inf')
+
+    for ratio_str, (width, height) in SD_XL_BASE_RATIOS.items():
+        ratio = width / height
+        difference = abs(target_ratio - ratio)
+
+        if difference < min_difference:
+            min_difference = difference
+            closest_ratio = ratio_str
+
+    return closest_ratio
+
+
+def resize_to_closest_aspect_ratio(image):
+    target_width, target_height = image.size
+    closest_ratio = find_closest_aspect_ratio(target_width, target_height)
+
+    # Get the dimensions from the closest aspect ratio in the dictionary
+    new_width, new_height = SD_XL_BASE_RATIOS[closest_ratio]
+
+    # Resize the image to the new dimensions while preserving the aspect ratio
+    resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
+
+    return resized_image
+
+
 ADAPTER_NAMES = [
     "TencentARC/t2i-adapter-canny-sdxl-1.0",
     "TencentARC/t2i-adapter-sketch-sdxl-1.0",
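The table added above matches the resolution buckets SDXL was trained on: every dimension is a multiple of 64 and each pair covers roughly 1024x1024 pixels of area, so snapping an input to its nearest bucket yields dimensions the pipeline handles natively. A quick sanity check of the lookup, using only the helpers from this hunk (the input sizes are hypothetical examples):

# 1920x1080 has ratio ~1.78, so it snaps to the "1.75" bucket.
closest = find_closest_aspect_ratio(1920, 1080)
assert closest == "1.75"
assert SD_XL_BASE_RATIOS[closest] == (1344, 768)

# A portrait 800x1200 input (ratio ~0.67) snaps to "0.68" -> (832, 1216).
assert find_closest_aspect_ratio(800, 1200) == "0.68"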
@@ -57,7 +115,7 @@ class LineartPreprocessor(Preprocessor):
         return self.model.to(device)
 
     def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
-        return self.model(image, detect_resolution=
+        return self.model(image, detect_resolution=512, image_resolution=1024)
 
 
 class MidasPreprocessor(Preprocessor):
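The old call's arguments are truncated in this rendering of the diff; the new version pins the lineart preprocessor to detect at 512 px and emit a 1024 px control image, matching SDXL's working resolution. Assuming `self.model` is controlnet_aux's `LineartDetector` (an assumption; the class name and keyword arguments match that API), the call is equivalent to this standalone sketch:

from controlnet_aux import LineartDetector  # assumption: controlnet_aux backend
import PIL.Image

detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
source = PIL.Image.open("input.png")  # hypothetical input file
# Detect line art at 512 px, then upscale the map to 1024 px for SDXL.
control_image = detector(source, detect_resolution=512, image_resolution=1024)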
@@ -273,6 +331,8 @@ class Model:
         if apply_preprocess:
             image = self.preprocessor(image)
 
+        image = resize_to_closest_aspect_ratio(image)
+
         generator = torch.Generator(device=self.device).manual_seed(seed)
         out = self.pipe(
             prompt=prompt,
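The resize call added here is the likely source of the "Runtime error" status at the top of the page: `resize_to_closest_aspect_ratio` references the bare name `Image`, while the rest of model.py appears to import `PIL.Image` (the type hints use `PIL.Image.Image`), and `Image.ANTIALIAS` itself was deprecated in Pillow 9.1 and removed in Pillow 10. A minimal sketch of a compatible helper, keeping the committed logic:

import PIL.Image


def resize_to_closest_aspect_ratio(image: PIL.Image.Image) -> PIL.Image.Image:
    target_width, target_height = image.size
    closest_ratio = find_closest_aspect_ratio(target_width, target_height)
    new_width, new_height = SD_XL_BASE_RATIOS[closest_ratio]
    # LANCZOS replaces the removed ANTIALIAS constant. Note the resize snaps
    # to the bucket's proportions, so the aspect ratio can change slightly.
    return image.resize((new_width, new_height), PIL.Image.LANCZOS)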