fix
Browse files- .gitignore +1 -1
- app.py +34 -12
- pipelines.py +4 -4
.gitignore
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
# Byte-compiled / optimized / DLL files
|
2 |
-
__pycache__/
|
3 |
*.py[cod]
|
4 |
*$py.class
|
5 |
|
|
|
1 |
# Byte-compiled / optimized / DLL files
|
2 |
+
**/__pycache__/
|
3 |
*.py[cod]
|
4 |
*$py.class
|
5 |
|
app.py
CHANGED
@@ -23,6 +23,17 @@ pipeline = None
|
|
23 |
rembg_session = rembg.new_session()
|
24 |
|
25 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
def check_input_image(input_image):
|
27 |
if input_image is None:
|
28 |
raise gr.Error("No image uploaded!")
|
@@ -67,13 +78,18 @@ def add_background(image, bg_color=(255, 255, 255)):
|
|
67 |
return Image.alpha_composite(background, image)
|
68 |
|
69 |
|
70 |
-
def preprocess_image(
|
71 |
"""
|
72 |
input image is a pil image in RGBA, return RGB image
|
73 |
"""
|
74 |
-
|
75 |
-
|
|
|
|
|
|
|
|
|
76 |
image = do_resize_content(image, foreground_ratio)
|
|
|
77 |
image = add_background(image, backgroud_color)
|
78 |
return image.convert("RGB")
|
79 |
|
@@ -150,8 +166,13 @@ with gr.Blocks() as demo:
|
|
150 |
with gr.Row():
|
151 |
with gr.Column():
|
152 |
with gr.Row():
|
153 |
-
|
154 |
-
|
|
|
|
|
|
|
|
|
|
|
155 |
back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
|
156 |
foreground_ratio = gr.Slider(
|
157 |
label="Foreground Ratio",
|
@@ -163,9 +184,13 @@ with gr.Blocks() as demo:
|
|
163 |
|
164 |
with gr.Column():
|
165 |
seed = gr.Number(value=1234, label="seed", precision=0)
|
166 |
-
guidance_scale = gr.Number(value=5.5, minimum=
|
167 |
-
step = gr.Number(value=50, minimum=
|
168 |
text_button = gr.Button("Generate 3D shape")
|
|
|
|
|
|
|
|
|
169 |
with gr.Column():
|
170 |
image_output = gr.Image(interactive=False, label="Output RGB image")
|
171 |
xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
|
@@ -188,14 +213,11 @@ with gr.Blocks() as demo:
|
|
188 |
output_model,
|
189 |
output_obj,
|
190 |
]
|
191 |
-
|
192 |
-
examples=[os.path.join("examples", i) for i in os.listdir("examples")],
|
193 |
-
inputs=[image_input],
|
194 |
-
)
|
195 |
|
196 |
text_button.click(fn=check_input_image, inputs=[image_input]).success(
|
197 |
fn=preprocess_image,
|
198 |
-
inputs=[image_input,
|
199 |
outputs=[processed_image],
|
200 |
).success(
|
201 |
fn=gen_image,
|
|
|
23 |
rembg_session = rembg.new_session()
|
24 |
|
25 |
|
26 |
+
def expand_to_square(image, bg_color=(0, 0, 0, 0)):
|
27 |
+
# expand image to 1:1
|
28 |
+
width, height = image.size
|
29 |
+
if width == height:
|
30 |
+
return image
|
31 |
+
new_size = (max(width, height), max(width, height))
|
32 |
+
new_image = Image.new("RGBA", new_size, bg_color)
|
33 |
+
paste_position = ((new_size[0] - width) // 2, (new_size[1] - height) // 2)
|
34 |
+
new_image.paste(image, paste_position)
|
35 |
+
return new_image
|
36 |
+
|
37 |
def check_input_image(input_image):
|
38 |
if input_image is None:
|
39 |
raise gr.Error("No image uploaded!")
|
|
|
78 |
return Image.alpha_composite(background, image)
|
79 |
|
80 |
|
81 |
+
def preprocess_image(image, background_choice, foreground_ratio, backgroud_color):
|
82 |
"""
|
83 |
input image is a pil image in RGBA, return RGB image
|
84 |
"""
|
85 |
+
print(background_choice)
|
86 |
+
if background_choice == "Alpha as mask":
|
87 |
+
background = Image.new("RGBA", image.size, (0, 0, 0, 0))
|
88 |
+
image = Image.alpha_composite(background, image)
|
89 |
+
else:
|
90 |
+
image = remove_background(image, rembg_session, force_remove=True)
|
91 |
image = do_resize_content(image, foreground_ratio)
|
92 |
+
image = expand_to_square(image)
|
93 |
image = add_background(image, backgroud_color)
|
94 |
return image.convert("RGB")
|
95 |
|
|
|
166 |
with gr.Row():
|
167 |
with gr.Column():
|
168 |
with gr.Row():
|
169 |
+
background_choice = gr.Radio([
|
170 |
+
"Alpha as mask",
|
171 |
+
"Auto Remove background"
|
172 |
+
], value="Alpha as mask",
|
173 |
+
label="backgroud choice")
|
174 |
+
# do_remove_background = gr.Checkbox(label=, value=True)
|
175 |
+
# force_remove = gr.Checkbox(label=, value=False)
|
176 |
back_groud_color = gr.ColorPicker(label="Background Color", value="#7F7F7F", interactive=False)
|
177 |
foreground_ratio = gr.Slider(
|
178 |
label="Foreground Ratio",
|
|
|
184 |
|
185 |
with gr.Column():
|
186 |
seed = gr.Number(value=1234, label="seed", precision=0)
|
187 |
+
guidance_scale = gr.Number(value=5.5, minimum=3, maximum=10, label="guidance_scale")
|
188 |
+
step = gr.Number(value=50, minimum=30, maximum=100, label="sample steps", precision=0)
|
189 |
text_button = gr.Button("Generate 3D shape")
|
190 |
+
gr.Examples(
|
191 |
+
examples=[os.path.join("examples", i) for i in os.listdir("examples")],
|
192 |
+
inputs=[image_input],
|
193 |
+
)
|
194 |
with gr.Column():
|
195 |
image_output = gr.Image(interactive=False, label="Output RGB image")
|
196 |
xyz_ouput = gr.Image(interactive=False, label="Output CCM image")
|
|
|
213 |
output_model,
|
214 |
output_obj,
|
215 |
]
|
216 |
+
|
|
|
|
|
|
|
217 |
|
218 |
text_button.click(fn=check_input_image, inputs=[image_input]).success(
|
219 |
fn=preprocess_image,
|
220 |
+
inputs=[image_input, background_choice, foreground_ratio, back_groud_color],
|
221 |
outputs=[processed_image],
|
222 |
).success(
|
223 |
fn=gen_image,
|
pipelines.py
CHANGED
@@ -92,7 +92,7 @@ class TwoStagePipeline(object):
|
|
92 |
stage1_images.pop(self.stage1_sampler.ref_position)
|
93 |
return stage1_images
|
94 |
|
95 |
-
def stage2_sample(self, pixel_img, stage1_images):
|
96 |
if type(pixel_img) == str:
|
97 |
pixel_img = Image.open(pixel_img)
|
98 |
|
@@ -112,8 +112,8 @@ class TwoStagePipeline(object):
|
|
112 |
self.stage2_sampler.sampler,
|
113 |
pixel_images=stage1_images,
|
114 |
ip=pixel_img,
|
115 |
-
step=
|
116 |
-
scale=
|
117 |
batch_size=self.stage2_sampler.batch_size,
|
118 |
ddim_eta=0.0,
|
119 |
dtype=self.stage2_sampler.dtype,
|
@@ -134,7 +134,7 @@ class TwoStagePipeline(object):
|
|
134 |
def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
|
135 |
pixel_img = do_resize_content(pixel_img, self.resize_rate)
|
136 |
stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
|
137 |
-
stage2_images = self.stage2_sample(pixel_img, stage1_images)
|
138 |
|
139 |
return {
|
140 |
"ref_img": pixel_img,
|
|
|
92 |
stage1_images.pop(self.stage1_sampler.ref_position)
|
93 |
return stage1_images
|
94 |
|
95 |
+
def stage2_sample(self, pixel_img, stage1_images, scale=5, step=50):
|
96 |
if type(pixel_img) == str:
|
97 |
pixel_img = Image.open(pixel_img)
|
98 |
|
|
|
112 |
self.stage2_sampler.sampler,
|
113 |
pixel_images=stage1_images,
|
114 |
ip=pixel_img,
|
115 |
+
step=step,
|
116 |
+
scale=scale,
|
117 |
batch_size=self.stage2_sampler.batch_size,
|
118 |
ddim_eta=0.0,
|
119 |
dtype=self.stage2_sampler.dtype,
|
|
|
134 |
def __call__(self, pixel_img, prompt="3D assets", scale=5, step=50):
|
135 |
pixel_img = do_resize_content(pixel_img, self.resize_rate)
|
136 |
stage1_images = self.stage1_sample(pixel_img, prompt, scale=scale, step=step)
|
137 |
+
stage2_images = self.stage2_sample(pixel_img, stage1_images, scale=scale, step=step)
|
138 |
|
139 |
return {
|
140 |
"ref_img": pixel_img,
|