Spaces • Runtime error
sandrawang1031 committed • Commit eca813c
1 Parent(s): dce3dbb
init
Files changed:
- .gitignore +138 -0
- app.py +45 -0
- model.py +147 -0
- requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,138 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+.docker/
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+# direnv
+.envrc
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+docker-compose-interpreter-local.yml
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# ide
+.idea/
+.vscode/
+
+# macos
+.DS_Store
+.envrc
app.py
ADDED
@@ -0,0 +1,45 @@
+import gradio as gr
+
+from model import VirtualStagingToolV2
+
+
+def predict(image, style, color_preference):
+    init_image = image.convert("RGB").resize((512, 512))
+
+    # Note: this builds the segmentation and inpainting pipelines on every request.
+    vs_tool = VirtualStagingToolV2(diffusion_version="stabilityai/stable-diffusion-2-inpainting")
+    output_images, transparent_mask_image = vs_tool.virtual_stage(
+        image=init_image, style=style, color_preference=color_preference, number_images=1)
+    # One return value per output component wired to btn.click below.
+    return output_images[0], transparent_mask_image
+
+
+image_blocks = gr.Blocks()
+with image_blocks as demo:
+    with gr.Group():
+        with gr.Box():
+            with gr.Row():
+                with gr.Column():
+                    image = gr.Image(source='upload', elem_id="image_upload",
+                                     type="pil", label="Upload",
+                                     ).style(height=400)
+                    with gr.Row(elem_id="prompt-container").style(mobile_collapse=False, equal_height=True):
+                        style = gr.Dropdown(
+                            ["Modern", "Coastal", "French country"],
+                            label="Design theme", elem_id="input-style"
+                        )
+
+                        color_preference = gr.Textbox(placeholder='Enter color preference',
+                                                      label="Color preference", elem_id="input-color")
+                        btn = gr.Button("Inpaint!").style(
+                            margin=False,
+                            rounded=(False, True, True, False),
+                            full_width=False,
+                        )
+                with gr.Column():
+                    mask_image = gr.Image(label="Mask image", elem_id="mask-img").style(height=400)
+                    image_out = gr.Image(label="Output", elem_id="output-img").style(height=400)
+
+    btn.click(fn=predict, inputs=[image, style, color_preference], outputs=[image_out, mask_image])
+
+image_blocks.launch()
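In Gradio Blocks, an event handler returns exactly one value per component listed in outputs, which is the contract predict follows above (two values for the two image components). A minimal, self-contained illustration of that wiring with toy components, not part of this Space:

    import gradio as gr

    def swap(a, b):
        # Two inputs in, two values back: one per output component below.
        return b, a

    with gr.Blocks() as toy:
        x = gr.Textbox(label="A")
        y = gr.Textbox(label="B")
        btn = gr.Button("Swap")
        out1 = gr.Textbox(label="Swapped A")
        out2 = gr.Textbox(label="Swapped B")
        btn.click(fn=swap, inputs=[x, y], outputs=[out1, out2])

    # toy.launch()  # uncomment to try it locally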
model.py
ADDED
@@ -0,0 +1,147 @@
+from collections import defaultdict
+import matplotlib.pyplot as plt
+import matplotlib.patches as mpatches
+from matplotlib import cm
+
+import numpy as np
+from PIL import Image
+
+import torch
+from transformers import AutoImageProcessor, UperNetForSemanticSegmentation
+from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
+from diffusers import StableDiffusionInpaintPipeline
+
+
+class VirtualStagingToolV2():
+
+    def __init__(self,
+                 segmentation_version='openmmlab/upernet-convnext-tiny',
+                 diffusion_version="stabilityai/stable-diffusion-2-inpainting"
+                 ):
+
+        self.segmentation_version = segmentation_version
+        self.diffusion_version = diffusion_version
+
+        # Semantic segmentation model used to decide which parts of the room to keep.
+        self.feature_extractor = AutoImageProcessor.from_pretrained(self.segmentation_version)
+        self.segmentation_model = UperNetForSemanticSegmentation.from_pretrained(self.segmentation_version)
+
+        # Inpainting pipeline that redraws the masked region from the prompt.
+        self.diffution_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+            self.diffusion_version,
+            torch_dtype=torch.float32,
+        )
+        self.diffution_pipeline = self.diffution_pipeline.to("cpu")
+
+    def _predict(self, image):
+        # Run semantic segmentation and return a per-pixel label map at the image's size.
+        inputs = self.feature_extractor(images=image, return_tensors="pt")
+        with torch.no_grad():
+            outputs = self.segmentation_model(**inputs)
+        prediction = \
+            self.feature_extractor.post_process_semantic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+        return prediction
+
+    def _save_mask(self, img, prediction_array, mask_items=[]):
+        # Binary mask: black (0) for labels to keep, white (255) for regions to inpaint.
+        mask = np.zeros_like(prediction_array, dtype=np.uint8)
+
+        mask[np.isin(prediction_array, mask_items)] = 0
+        mask[~np.isin(prediction_array, mask_items)] = 255
+
+        # Create a PIL Image object from the mask
+        mask_image = Image.fromarray(mask, mode='L')
+
+        return mask_image
+
+    def _save_transparent_mask(self, img, prediction_array, mask_items=[]):
+        # RGBA preview of the mask: everything outside the kept labels is whitened out.
+        mask = np.array(img)
+        mask[~np.isin(prediction_array, mask_items), :] = 255
+        mask_image = Image.fromarray(mask).convert('RGBA')
+
+        # Set the transparency of the whitened pixels to 0 (fully transparent)
+        mask_data = mask_image.getdata()
+        mask_data = [(r, g, b, 0) if r == 255 else (r, g, b, 255) for (r, g, b, a) in mask_data]
+        mask_image.putdata(mask_data)
+
+        return mask_image
+
+    def get_mask(self, image_path=None, image=None):
+        if image_path:
+            image = Image.open(image_path)
+        elif not image:
+            raise ValueError("no image provided")
+
+        prediction = self._predict(image)
+
+        label_ids = np.unique(prediction)
+
+        # ADE20K label ids for structural elements (wall, floor, ceiling, windowpane, door)
+        # that are kept (not inpainted) by default.
+        mask_items = [0, 3, 5, 8, 14]
+        room = 'room'  # generic fallback so the prompt still reads sensibly if no room type matches
+
+        if 1 in label_ids or 25 in label_ids:
+            mask_items = [1, 2, 4, 25, 32]
+            room = 'backyard'
+        elif 73 in label_ids or 50 in label_ids or 61 in label_ids:
+            mask_items = [0, 3, 5, 8, 14, 50, 61, 71, 118, 124, 129]
+            room = 'kitchen'
+        elif 37 in label_ids or 65 in label_ids or (27 in label_ids and 47 in label_ids and 70 in label_ids):
+            mask_items = [0, 3, 5, 8, 14, 27, 65]
+            room = 'bathroom'
+        elif 7 in label_ids:
+            room = 'bedroom'
+        elif 23 in label_ids or 49 in label_ids:
+            room = 'living room'
+
+        label_ids_without_mask = [i for i in label_ids if i not in mask_items]
+
+        items = [self.segmentation_model.config.id2label[i] for i in label_ids_without_mask]
+
+        mask_image = self._save_mask(image, prediction, mask_items)
+        transparent_mask_image = self._save_transparent_mask(image, prediction, mask_items)
+        return mask_image, transparent_mask_image, image, items, room
+
+    def _edit_image(self, init_image, mask_image, prompt,  # height, width,
+                    number_images=1):
+
+        init_image = init_image.resize((512, 512)).convert("RGB")
+        mask_image = mask_image.resize((512, 512)).convert("RGB")
+
+        output_images = self.diffution_pipeline(
+            prompt=prompt, image=init_image, mask_image=mask_image,
+            # width=width, height=height,
+            num_images_per_prompt=number_images).images
+        return output_images
+
+    def virtual_stage(self, image_path=None, image=None, style=None, color_preference=None, number_images=1):
+        mask_image, transparent_mask_image, init_image, items, room = self.get_mask(image_path, image)
+        if not style:
+            raise ValueError('style not provided.')
+        if not color_preference:
+            raise ValueError('color_preference not provided.')
+
+        # Keep only the furniture/fixtures that make sense for the detected room type.
+        if room == 'kitchen':
+            items = [i for i in items if i in ['kitchen island', 'cabinet', 'shelf', 'counter', 'countertop', 'stool']]
+        elif room == 'bedroom':
+            items = [i for i in items if i in ['bed', 'table', 'chest of drawers', 'desk', 'armchair', 'wardrobe']]
+        elif room == 'bathroom':
+            items = [i for i in items if
+                     i in ['shower', 'bathtub', 'chest of drawers', 'counter', 'countertop', 'sink']]
+
+        items = ', '.join(items)
+        prompt = f'{items}, high resolution, in the {style} style {room} in {color_preference}'
+        print(prompt)
+
+        output_images = self._edit_image(init_image, mask_image, prompt, number_images)
+
+        final_output_images = []
+        for output_image in output_images:
+            # Scale the 512x512 diffusion output back to the original image size.
+            output_image = output_image.resize(init_image.size)
+            final_output_images.append(output_image)
+        return final_output_images, transparent_mask_image
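For reference, the staging tool can also be exercised without the Gradio front end. A minimal sketch, assuming a local photo at living_room.jpg (hypothetical path) and enough memory to run both models on CPU; the first call downloads the upernet-convnext-tiny and stable-diffusion-2-inpainting checkpoints:

    from PIL import Image
    from model import VirtualStagingToolV2

    tool = VirtualStagingToolV2()  # defaults match the checkpoints used in app.py

    # Hypothetical local photo of an empty room; replace with a real file.
    photo = Image.open("living_room.jpg").convert("RGB").resize((512, 512))

    staged_images, transparent_mask = tool.virtual_stage(
        image=photo, style="Modern", color_preference="warm beige", number_images=1)

    staged_images[0].save("staged.png")
    transparent_mask.save("mask.png")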
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+transformers==4.29.0
+torch==1.11.0
+diffusers==0.16.1
+accelerate==0.19.0
+matplotlib==3.6.2
+pillow==9.2.0
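Note that gradio itself is not pinned here; on a Hugging Face Space it comes from the SDK version declared in the README, so a local run would need it installed separately (the .style()/gr.Box() calls in app.py assume a Gradio 3.x release). A small, purely illustrative sanity check that the pinned libraries import with the expected versions:

    # Illustrative only: confirm the pinned packages are importable and print their versions.
    import torch, transformers, diffusers, accelerate, PIL

    for name, module in [("torch", torch), ("transformers", transformers),
                         ("diffusers", diffusers), ("accelerate", accelerate),
                         ("pillow", PIL)]:
        print(f"{name}: {module.__version__}")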