raphael-gl HF staff commited on
Commit
56bed35
0 Parent(s):

Duplicate from raphael-gl/ai-days-image-background-substitution

Browse files
Files changed (4) hide show
  1. .gitattributes +35 -0
  2. README.md +13 -0
  3. app.py +165 -0
  4. requirements.txt +13 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Ai Days Image Background Substitution
3
+ emoji: 🔥
4
+ colorFrom: pink
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 3.35.2
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: raphael-gl/ai-days-image-background-substitution
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+
5
+ import cv2
6
+ from diffusers import StableDiffusionPipeline
7
+ import gradio as gr
8
+ # import mediapipe as mp
9
+ import numpy as np
10
+ import PIL
11
+ import torch.cuda
12
+ from transformers import pipeline
13
+
14
+ os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
15
+
16
+
17
+ logging.basicConfig(level=logging.INFO,
18
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
19
+ force=True)
20
+
21
+ LOG = logging.getLogger(__name__)
22
+
23
+ LOG.info("Loading image segmentation model")
24
+
25
+ seg_kwargs = {
26
+ "task": "image-segmentation",
27
+ "model": "nvidia/segformer-b0-finetuned-ade-512-512"
28
+ }
29
+
30
+ img_segmentation_model = pipeline(**seg_kwargs)
31
+
32
+
33
+ # mp_selfie_segmentation = mp.solutions.selfie_segmentation
34
+ # img_segmentation_model = mp_selfie_segmentation.SelfieSegmentation(model_selection=0)
35
+
36
+
37
+ LOG.info("Loading diffusion model")
38
+
39
+ diffusion = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
40
+
41
+ if torch.cuda.is_available():
42
+ LOG.info("Moving diffusion model to GPU")
43
+ diffusion.to('cuda')
44
+
45
+
46
+ def image_preprocess(image: PIL.Image):
47
+ LOG.info("Preprocessing image %s", image)
48
+ start = time.time()
49
+ # image = PIL.ImageOps.exif_transpose(image)
50
+ image = image.convert("RGB")
51
+ image = resize_image(image)
52
+ # image = np.array(image)
53
+ # # Convert RGB to BGR
54
+ # image = image[:, :, ::-1].copy()
55
+ elapsed = time.time() - start
56
+ LOG.info("Image preprocessed, %.2f seconds elapsed", elapsed)
57
+ return image
58
+
59
+
60
+ def resize_image(image: PIL.Image):
61
+ width, height = image.size
62
+ ratio = max(width / 512, height / 512)
63
+ width = int(width / ratio) // 8 * 8
64
+ height = int(height / ratio) // 8 * 8
65
+ image = image.resize((width, height))
66
+ return image
67
+
68
+
69
+ def extract_selfie_mask(threshold, image):
70
+ LOG.info("Extracting selfie mask")
71
+ start = time.time()
72
+ segments = img_segmentation_model(image)
73
+ kept = None
74
+ for s in segments:
75
+ if s['score'] is None:
76
+ s['score'] = 1
77
+ if s['label'] == 'person' and s['score'] > 0.99:
78
+ if not kept:
79
+ kept = s
80
+ elif kept['score'] < s['score']:
81
+ kept = s
82
+ if not kept:
83
+ LOG.info("No person found in the photo, skipping")
84
+ mask = np.zeros((image.size[1], image.size[0], 3), dtype='float32')
85
+ else:
86
+ mask = kept['mask']
87
+ mask = np.array(mask, dtype='float32')
88
+
89
+ cv2.threshold(mask, threshold, 1, cv2.THRESH_BINARY, dst=mask)
90
+ cv2.dilate(mask, np.ones((5, 5), np.uint8), iterations=1, dst=mask)
91
+ cv2.blur(mask, (10, 10), dst=mask)
92
+
93
+ elapsed = time.time() - start
94
+ LOG.info("Selfie extracted, %.2f seconds elapsed", elapsed)
95
+ return mask
96
+
97
+
98
+ def generate_background(prompt, num_inference_steps, height, width):
99
+ LOG.info("Generating background")
100
+ start = time.time()
101
+ background = diffusion(
102
+ prompt=prompt,
103
+ num_inference_steps=int(num_inference_steps),
104
+ height=height,
105
+ width=width
106
+ )
107
+ nsfw = background.nsfw_content_detected[0]
108
+ background = background.images[0]
109
+
110
+ if nsfw:
111
+ LOG.info('NSFW detected, skipping')
112
+ background = np.zeros((height, width, 3), dtype='uint8')
113
+ else:
114
+ background = np.array(background)
115
+ # Convert RGB to BGR
116
+ background = background[:, :, ::-1].copy()
117
+
118
+ elapsed = time.time() - start
119
+ LOG.info("Background generated, elapsed %.2f seconds", elapsed)
120
+ return background
121
+
122
+
123
+ def merge_selfie_and_background(selfie, background, mask):
124
+ LOG.info("Merging extracted selfie and generated background")
125
+ selfie = np.array(selfie)
126
+ # Convert RGB to BGR
127
+ selfie = selfie[:, :, ::-1].copy()
128
+ cv2.blendLinear(selfie, background, mask, 1 - mask, dst=selfie)
129
+ selfie = cv2.cvtColor(selfie, cv2.COLOR_BGR2RGB)
130
+ selfie = PIL.Image.fromarray(selfie)
131
+ return selfie
132
+
133
+
134
+ def demo(threshold, image, prompt, num_inference_steps):
135
+ LOG.info("Processing image")
136
+ try:
137
+ image = image_preprocess(image)
138
+ mask = extract_selfie_mask(threshold, image)
139
+ background = generate_background(prompt, num_inference_steps,
140
+ image.size[1], image.size[0])
141
+ output = merge_selfie_and_background(image, background, mask)
142
+ except Exception as e:
143
+ LOG.error("Some unexpected error occured")
144
+ LOG.exception(e)
145
+ raise
146
+ return output
147
+
148
+
149
+ iface = gr.Interface(
150
+ fn=demo,
151
+ inputs=[
152
+ gr.Slider(minimum=0.1, maximum=1, step=0.05, label="Selfie segmentation threshold",
153
+ value=0.8),
154
+ gr.Image(type='pil', label="Upload your selfie"),
155
+ gr.Text(value="a photo of the Eiffel tower on the right side",
156
+ label="Background description"),
157
+ gr.Slider(minimum=5, maximum=100, step=5, label="Diffusion inference steps",
158
+ value=50)
159
+ ],
160
+ outputs=[
161
+ gr.Image(label="Invent yourself a life :)")
162
+ ])
163
+
164
+ # iface.launch(server_name="0.0.0.0", server_port=6443)
165
+ iface.launch()
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ opencv-python
3
+ pillow
4
+ timm
5
+ mediapipe
6
+ diffusers
7
+ transformers
8
+ scipy
9
+ ftfy
10
+ accelerate
11
+ torch
12
+ numpy
13
+ hf_transfer