Ashoka74 committed on
Commit 5ffee47 (1 parent: 07aff55)

Create app.py

Files changed (1)
  1. app.py +1301 -0
app.py ADDED
@@ -0,0 +1,1301 @@
1
+ import spaces
2
+ import os
3
+ import math
4
+ import gradio as gr
5
+ import numpy as np
6
+ import torch
7
+ import safetensors.torch as sf
8
+ import db_examples
9
+ import datetime
10
+ from pathlib import Path
11
+ from io import BytesIO
12
+
13
+ from PIL import Image
14
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
15
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
16
+ from diffusers.models.attention_processor import AttnProcessor2_0
17
+ from transformers import CLIPTextModel, CLIPTokenizer
18
+ import dds_cloudapi_sdk
19
+ from dds_cloudapi_sdk import Config, Client, TextPrompt
20
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
21
+ from enum import Enum
22
+ from torch.hub import download_url_to_file
23
+ import tempfile
24
+
25
+ from sam2.build_sam import build_sam2
26
+
27
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
28
+ import cv2
29
+
30
+ from typing import Optional
31
+
32
+ from depth_anything_v2.dpt import DepthAnythingV2
33
+
34
+ import httpx
35
+
36
+ client = httpx.Client(timeout=httpx.Timeout(10.0)) # Set timeout to 10 seconds
37
+
38
+
39
+
40
+ import supervision as sv
41
+ import torch
42
+ from PIL import Image
43
+
44
+
45
+
46
+ try:
47
+ import xformers
48
+ import xformers.ops
49
+ XFORMERS_AVAILABLE = True
50
+ print("xformers is available - Using memory efficient attention")
51
+ except ImportError:
52
+ XFORMERS_AVAILABLE = False
53
+ print("xformers not available - Using default attention")
54
+
55
+ # Memory optimizations for RTX 2070
56
+ torch.backends.cudnn.benchmark = True
57
+ if torch.cuda.is_available():
58
+ torch.backends.cuda.matmul.allow_tf32 = True
59
+ torch.backends.cudnn.allow_tf32 = True
60
+ # Cap CUDA allocator block splitting for the RTX 2070 (set via the allocator env var;
+ # torch.backends.cuda has no max_split_size_mb attribute, so assigning one is a silent no-op)
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")
62
+ device = torch.device('cuda')
63
+ else:
64
+ device = torch.device('cpu')
65
+
66
+ # 'stablediffusionapi/realistic-vision-v51'
67
+ # 'runwayml/stable-diffusion-v1-5'
68
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
69
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
70
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
71
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
72
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
73
+ # Load model directly
74
+ from transformers import AutoModelForImageSegmentation
75
+ rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-2.0", trust_remote_code=True)
76
+
77
+
78
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
79
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
80
+ model = model.to(device)
81
+ model.eval()
82
+
83
+ # Change UNet
84
+
85
+ with torch.no_grad():
86
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
87
+ new_conv_in.weight.zero_()
88
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
89
+ new_conv_in.bias = unet.conv_in.bias
90
+ unet.conv_in = new_conv_in
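+ # The IC-Light "fc" model expects an 8-channel UNet input: 4 channels for the noisy latent
+ # plus 4 channels for the VAE-encoded foreground condition concatenated in hooked_unet_forward
+ # below. The new input channels are zero-initialised so the base weights behave unchanged
+ # until the IC-Light offsets are merged in further down.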
91
+
92
+
93
+ unet_original_forward = unet.forward
94
+
95
+
96
+ def enable_efficient_attention():
+     if XFORMERS_AVAILABLE:
+         try:
+             # RTX 2070 specific settings
+             unet.set_use_memory_efficient_attention_xformers(True)
+             vae.set_use_memory_efficient_attention_xformers(True)
+             print("Enabled xformers memory efficient attention")
+         except Exception as e:
+             print(f"Xformers error: {e}")
+             print("Falling back to sliced attention")
+             # Use sliced UNet attention and sliced VAE decoding for RTX 2070
+             # (diffusers exposes set_attention_slice on the UNet and enable_slicing on the VAE)
+             unet.set_attention_slice(4)
+             vae.enable_slicing()
+             unet.set_attn_processor(AttnProcessor2_0())
+             vae.set_attn_processor(AttnProcessor2_0())
+     else:
+         # Fallback for when xformers is not available
+         print("Using sliced attention")
+         unet.set_attention_slice(4)
+         vae.enable_slicing()
+         unet.set_attn_processor(AttnProcessor2_0())
+         vae.set_attn_processor(AttnProcessor2_0())
118
+
119
+ # Add memory clearing function
120
+ def clear_memory():
121
+ if torch.cuda.is_available():
122
+ torch.cuda.empty_cache()
123
+ torch.cuda.synchronize()
124
+
125
+ # Enable efficient attention
126
+ enable_efficient_attention()
127
+
128
+
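+ # The hook below reads the VAE-encoded foreground latents from
+ # cross_attention_kwargs['concat_conds'], repeats them to match the batch size, concatenates
+ # them to the noisy latent along the channel axis, and then clears cross_attention_kwargs so
+ # the extra key never reaches the attention processors.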
129
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
130
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
131
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
132
+ new_sample = torch.cat([sample, c_concat], dim=1)
133
+ kwargs['cross_attention_kwargs'] = {}
134
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
135
+
136
+
137
+ unet.forward = hooked_unet_forward
138
+
139
+ # Load
140
+
141
+ # Model paths
142
+ model_path = './models/iclight_sd15_fc.safetensors'
143
+ model_path2 = './checkpoints/depth_anything_v2_vits.pth'
144
+ model_path3 = './checkpoints/sam2_hiera_large.pt'
145
+ model_path4 = './checkpoints/config.json'
146
+ model_path5 = './checkpoints/preprocessor_config.json'
147
+ model_path6 = './configs/sam2_hiera_l.yaml'
148
+ model_path7 = './mvadapter_i2mv_sdxl.safetensors'
149
+
150
+ # Base URL for the repository
151
+ BASE_URL = 'https://huggingface.co/Ashoka74/Placement/resolve/main/'
152
+
153
+ # Model URLs
154
+ model_urls = {
155
+ model_path: 'iclight_sd15_fc.safetensors',
156
+ model_path2: 'depth_anything_v2_vits.pth',
157
+ model_path3: 'sam2_hiera_large.pt',
158
+ model_path4: 'config.json',
159
+ model_path5: 'preprocessor_config.json',
160
+ model_path6: 'sam2_hiera_l.yaml',
161
+ model_path7: 'mvadapter_i2mv_sdxl.safetensors'
162
+ }
163
+
164
+ # Ensure directories exist
165
+ def ensure_directories():
166
+ for path in model_urls.keys():
167
+ os.makedirs(os.path.dirname(path), exist_ok=True)
168
+
169
+ # Download models
170
+ def download_models():
171
+ for local_path, filename in model_urls.items():
172
+ if not os.path.exists(local_path):
173
+ try:
174
+ url = f"{BASE_URL}{filename}"
175
+ print(f"Downloading {filename}")
176
+ download_url_to_file(url, local_path)
177
+ print(f"Successfully downloaded {filename}")
178
+ except Exception as e:
179
+ print(f"Error downloading {filename}: {e}")
180
+
181
+ ensure_directories()
182
+
183
+ download_models()
184
+
185
+
186
+
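+ # The IC-Light checkpoint stores weight *offsets* relative to the UNet loaded above, so the
+ # two state dicts are summed key-by-key before being loaded back into the UNet.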
187
+ sd_offset = sf.load_file(model_path)
188
+ sd_origin = unet.state_dict()
189
+ keys = sd_origin.keys()
190
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
191
+ unet.load_state_dict(sd_merged, strict=True)
192
+ del sd_offset, sd_origin, sd_merged, keys
193
+
194
+ # Device
195
+
196
+ # device = torch.device('cuda')
197
+ # text_encoder = text_encoder.to(device=device, dtype=torch.float16)
198
+ # vae = vae.to(device=device, dtype=torch.bfloat16)
199
+ # unet = unet.to(device=device, dtype=torch.float16)
200
+ # rmbg = rmbg.to(device=device, dtype=torch.float32)
201
+
202
+
203
+ # Device and dtype setup
204
+ device = torch.device('cuda')
205
+ dtype = torch.float16 # RTX 2070 works well with float16
206
+
207
+ # Memory optimizations for RTX 2070
208
+ torch.backends.cudnn.benchmark = True
209
+ if torch.cuda.is_available():
210
+ torch.backends.cuda.matmul.allow_tf32 = True
211
+ torch.backends.cudnn.allow_tf32 = True
212
+ # Use a smaller allocator split size here to avoid OOM (again via the env var; this overrides the value set above)
+ os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
214
+
215
+ # Move models to device with consistent dtype
216
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
217
+ vae = vae.to(device=device, dtype=dtype) # Changed from bfloat16 to float16
218
+ unet = unet.to(device=device, dtype=dtype)
219
+ rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
220
+
221
+
222
+ ddim_scheduler = DDIMScheduler(
223
+ num_train_timesteps=1000,
224
+ beta_start=0.00085,
225
+ beta_end=0.012,
226
+ beta_schedule="scaled_linear",
227
+ clip_sample=False,
228
+ set_alpha_to_one=False,
229
+ steps_offset=1,
230
+ )
231
+
232
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
233
+ num_train_timesteps=1000,
234
+ beta_start=0.00085,
235
+ beta_end=0.012,
236
+ steps_offset=1
237
+ )
238
+
239
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
240
+ num_train_timesteps=1000,
241
+ beta_start=0.00085,
242
+ beta_end=0.012,
243
+ algorithm_type="sde-dpmsolver++",
244
+ use_karras_sigmas=True,
245
+ steps_offset=1
246
+ )
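+ # Both pipelines below are wired to the DPM++ 2M SDE (Karras sigmas) scheduler; the DDIM and
+ # Euler-ancestral schedulers above are defined only as readily swappable alternatives.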
247
+
248
+ # Pipelines
249
+
250
+ t2i_pipe = StableDiffusionPipeline(
251
+ vae=vae,
252
+ text_encoder=text_encoder,
253
+ tokenizer=tokenizer,
254
+ unet=unet,
255
+ scheduler=dpmpp_2m_sde_karras_scheduler,
256
+ safety_checker=None,
257
+ requires_safety_checker=False,
258
+ feature_extractor=None,
259
+ image_encoder=None
260
+ )
261
+
262
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
263
+ vae=vae,
264
+ text_encoder=text_encoder,
265
+ tokenizer=tokenizer,
266
+ unet=unet,
267
+ scheduler=dpmpp_2m_sde_karras_scheduler,
268
+ safety_checker=None,
269
+ requires_safety_checker=False,
270
+ feature_extractor=None,
271
+ image_encoder=None
272
+ )
273
+
274
+
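+ # Prompt encoding: token sequences are split into 75-token chunks (plus BOS/EOS), each chunk
+ # is run through the CLIP text encoder, and the positive/negative embeddings are repeated to
+ # the same number of chunks so classifier-free guidance sees matching shapes.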
275
+ @torch.inference_mode()
276
+ def encode_prompt_inner(txt: str):
277
+ max_length = tokenizer.model_max_length
278
+ chunk_length = tokenizer.model_max_length - 2
279
+ id_start = tokenizer.bos_token_id
280
+ id_end = tokenizer.eos_token_id
281
+ id_pad = id_end
282
+
283
+ def pad(x, p, i):
284
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
285
+
286
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
287
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
288
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
289
+
290
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
291
+ conds = text_encoder(token_ids).last_hidden_state
292
+
293
+ return conds
294
+
295
+
296
+ @torch.inference_mode()
297
+ def encode_prompt_pair(positive_prompt, negative_prompt):
298
+ c = encode_prompt_inner(positive_prompt)
299
+ uc = encode_prompt_inner(negative_prompt)
300
+
301
+ c_len = float(len(c))
302
+ uc_len = float(len(uc))
303
+ max_count = max(c_len, uc_len)
304
+ c_repeat = int(math.ceil(max_count / c_len))
305
+ uc_repeat = int(math.ceil(max_count / uc_len))
306
+ max_chunk = max(len(c), len(uc))
307
+
308
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
309
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
310
+
311
+ c = torch.cat([p[None, ...] for p in c], dim=1)
312
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
313
+
314
+ return c, uc
315
+
316
+
317
+ @torch.inference_mode()
318
+ def pytorch2numpy(imgs, quant=True):
319
+ results = []
320
+ for x in imgs:
321
+ y = x.movedim(0, -1)
322
+
323
+ if quant:
324
+ y = y * 127.5 + 127.5
325
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
326
+ else:
327
+ y = y * 0.5 + 0.5
328
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
329
+
330
+ results.append(y)
331
+ return results
332
+
333
+
334
+ @torch.inference_mode()
335
+ def numpy2pytorch(imgs):
336
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
337
+ h = h.movedim(-1, 1)
338
+ return h
339
+
340
+
341
+ def resize_and_center_crop(image, target_width, target_height):
342
+ pil_image = Image.fromarray(image)
343
+ original_width, original_height = pil_image.size
344
+ scale_factor = max(target_width / original_width, target_height / original_height)
345
+ resized_width = int(round(original_width * scale_factor))
346
+ resized_height = int(round(original_height * scale_factor))
347
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
348
+ left = (resized_width - target_width) / 2
349
+ top = (resized_height - target_height) / 2
350
+ right = (resized_width + target_width) / 2
351
+ bottom = (resized_height + target_height) / 2
352
+ cropped_image = resized_image.crop((left, top, right, bottom))
353
+ return np.array(cropped_image)
354
+
355
+
356
+ def resize_without_crop(image, target_width, target_height):
357
+ pil_image = Image.fromarray(image)
358
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
359
+ return np.array(resized_image)
360
+
361
+
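+ # run_rmbg returns two arrays: the foreground composited over neutral grey (used as the
+ # relighting condition) and an RGBA cutout whose alpha channel comes from the matting model.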
362
+ @torch.inference_mode()
363
+ def run_rmbg(img, sigma=0.0):
364
+ # Convert RGBA to RGB if needed
365
+ if img.shape[-1] == 4:
366
+ # Use white background for alpha composition
367
+ alpha = img[..., 3:] / 255.0
368
+ rgb = img[..., :3]
369
+ white_bg = np.ones_like(rgb) * 255
370
+ img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
371
+
372
+ H, W, C = img.shape
373
+ assert C == 3
374
+ k = (256.0 / float(H * W)) ** 0.5
375
+ feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
376
+ feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
377
+ alpha = rmbg(feed)[0][0]
378
+ alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
379
+ alpha = alpha.movedim(1, -1)[0]
380
+ alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
381
+
382
+ # Create RGBA image
383
+ rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
384
+ result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
385
+ return result.clip(0, 255).astype(np.uint8), rgba
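+ # process() runs two passes: a low-resolution generation conditioned on the foreground latent
+ # (and an optional initial-latent background), followed by an img2img refinement pass at
+ # highres_scale, before resizing the result back to the input resolution.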
386
+ @torch.inference_mode()
387
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
+     clear_memory()
+
+     # Get input dimensions
+     input_height, input_width = input_fg.shape[:2]
+
+     bg_source = BGSource(bg_source)
+
+     # No background image is passed into this text-conditioned path, so start from None.
+     input_bg = None
+
+     if bg_source in (BGSource.UPLOAD, BGSource.UPLOAD_FLIP):
+         # No uploaded background is available in this path; fall through to the
+         # text-to-image branch below.
+         pass
+     elif bg_source == BGSource.GREY:
+         input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
+     elif bg_source == BGSource.LEFT:
+         gradient = np.linspace(255, 0, input_width)
+         image = np.tile(gradient, (input_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.RIGHT:
+         gradient = np.linspace(0, 255, input_width)
+         image = np.tile(gradient, (input_height, 1))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.TOP:
+         gradient = np.linspace(255, 0, input_height)[:, None]
+         image = np.tile(gradient, (1, input_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     elif bg_source == BGSource.BOTTOM:
+         gradient = np.linspace(0, 255, input_height)[:, None]
+         image = np.tile(gradient, (1, input_width))
+         input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
+     else:
+         raise ValueError('Wrong initial latent!')
420
+
421
+ rng = torch.Generator(device=device).manual_seed(int(seed))
422
+
423
+ # Use input dimensions directly
424
+ fg = resize_without_crop(input_fg, input_width, input_height)
425
+
426
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
427
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
428
+
429
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
430
+
431
+ if input_bg is None:
432
+ latents = t2i_pipe(
433
+ prompt_embeds=conds,
434
+ negative_prompt_embeds=unconds,
435
+ width=input_width,
436
+ height=input_height,
437
+ num_inference_steps=steps,
438
+ num_images_per_prompt=num_samples,
439
+ generator=rng,
440
+ output_type='latent',
441
+ guidance_scale=cfg,
442
+ cross_attention_kwargs={'concat_conds': concat_conds},
443
+ ).images.to(vae.dtype) / vae.config.scaling_factor
444
+ else:
445
+ bg = resize_without_crop(input_bg, input_width, input_height)
446
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
447
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
448
+ latents = i2i_pipe(
449
+ image=bg_latent,
450
+ strength=lowres_denoise,
451
+ prompt_embeds=conds,
452
+ negative_prompt_embeds=unconds,
453
+ width=input_width,
454
+ height=input_height,
455
+ num_inference_steps=int(round(steps / lowres_denoise)),
456
+ num_images_per_prompt=num_samples,
457
+ generator=rng,
458
+ output_type='latent',
459
+ guidance_scale=cfg,
460
+ cross_attention_kwargs={'concat_conds': concat_conds},
461
+ ).images.to(vae.dtype) / vae.config.scaling_factor
462
+
463
+ pixels = vae.decode(latents).sample
464
+ pixels = pytorch2numpy(pixels)
465
+ pixels = [resize_without_crop(
466
+ image=p,
467
+ target_width=int(round(input_width * highres_scale / 64.0) * 64),
468
+ target_height=int(round(input_height * highres_scale / 64.0) * 64))
469
+ for p in pixels]
470
+
471
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
472
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
473
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
474
+
475
+ highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
476
+
477
+ fg = resize_without_crop(input_fg, highres_width, highres_height)
478
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
479
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
480
+
481
+ latents = i2i_pipe(
482
+ image=latents,
483
+ strength=highres_denoise,
484
+ prompt_embeds=conds,
485
+ negative_prompt_embeds=unconds,
486
+ width=highres_width,
487
+ height=highres_height,
488
+ num_inference_steps=int(round(steps / highres_denoise)),
489
+ num_images_per_prompt=num_samples,
490
+ generator=rng,
491
+ output_type='latent',
492
+ guidance_scale=cfg,
493
+ cross_attention_kwargs={'concat_conds': concat_conds},
494
+ ).images.to(vae.dtype) / vae.config.scaling_factor
495
+
496
+ pixels = vae.decode(latents).sample
497
+ pixels = pytorch2numpy(pixels)
498
+
499
+ # Resize back to input dimensions
500
+ pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
501
+ pixels = np.stack(pixels)
502
+
503
+ return pixels
504
+
505
+ @torch.inference_mode()
506
+ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
507
+ clear_memory()
508
+ bg_source = BGSource(bg_source)
509
+
510
+ if bg_source == BGSource.UPLOAD:
511
+ pass
512
+ elif bg_source == BGSource.UPLOAD_FLIP:
513
+ input_bg = np.fliplr(input_bg)
514
+ elif bg_source == BGSource.GREY:
515
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
516
+ elif bg_source == BGSource.LEFT:
517
+ gradient = np.linspace(224, 32, image_width)
518
+ image = np.tile(gradient, (image_height, 1))
519
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
520
+ elif bg_source == BGSource.RIGHT:
521
+ gradient = np.linspace(32, 224, image_width)
522
+ image = np.tile(gradient, (image_height, 1))
523
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
524
+ elif bg_source == BGSource.TOP:
525
+ gradient = np.linspace(224, 32, image_height)[:, None]
526
+ image = np.tile(gradient, (1, image_width))
527
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
528
+ elif bg_source == BGSource.BOTTOM:
529
+ gradient = np.linspace(32, 224, image_height)[:, None]
530
+ image = np.tile(gradient, (1, image_width))
531
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
532
+ else:
533
+ raise ValueError('Wrong background source!')
534
+
535
+ rng = torch.Generator(device=device).manual_seed(seed)
536
+
537
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
538
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
539
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
540
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
541
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
542
+
543
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
544
+
545
+ latents = t2i_pipe(
546
+ prompt_embeds=conds,
547
+ negative_prompt_embeds=unconds,
548
+ width=image_width,
549
+ height=image_height,
550
+ num_inference_steps=steps,
551
+ num_images_per_prompt=num_samples,
552
+ generator=rng,
553
+ output_type='latent',
554
+ guidance_scale=cfg,
555
+ cross_attention_kwargs={'concat_conds': concat_conds},
556
+ ).images.to(vae.dtype) / vae.config.scaling_factor
557
+
558
+ pixels = vae.decode(latents).sample
559
+ pixels = pytorch2numpy(pixels)
560
+ pixels = [resize_without_crop(
561
+ image=p,
562
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
563
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
564
+ for p in pixels]
565
+
566
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
567
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
568
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
569
+
570
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
571
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
572
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
573
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
574
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
575
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
576
+
577
+ latents = i2i_pipe(
578
+ image=latents,
579
+ strength=highres_denoise,
580
+ prompt_embeds=conds,
581
+ negative_prompt_embeds=unconds,
582
+ width=image_width,
583
+ height=image_height,
584
+ num_inference_steps=int(round(steps / highres_denoise)),
585
+ num_images_per_prompt=num_samples,
586
+ generator=rng,
587
+ output_type='latent',
588
+ guidance_scale=cfg,
589
+ cross_attention_kwargs={'concat_conds': concat_conds},
590
+ ).images.to(vae.dtype) / vae.config.scaling_factor
591
+
592
+ pixels = vae.decode(latents).sample
593
+ pixels = pytorch2numpy(pixels, quant=False)
594
+
595
+ clear_memory()
596
+ return pixels, [fg, bg]
597
+
598
+
599
+ @torch.inference_mode()
600
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
601
+ input_fg, matting = run_rmbg(input_fg)
602
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
603
+ return input_fg, results
604
+
605
+
606
+
607
+ @torch.inference_mode()
608
+ def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
609
+ bg_source = BGSource(bg_source)
610
+
611
+ # Convert numerical inputs to appropriate types
612
+ image_width = int(image_width)
613
+ image_height = int(image_height)
614
+ num_samples = int(num_samples)
615
+ seed = int(seed)
616
+ steps = int(steps)
617
+ cfg = float(cfg)
618
+ highres_scale = float(highres_scale)
619
+ highres_denoise = float(highres_denoise)
620
+
621
+ if bg_source == BGSource.UPLOAD:
622
+ pass
623
+ elif bg_source == BGSource.UPLOAD_FLIP:
624
+ input_bg = np.fliplr(input_bg)
625
+ elif bg_source == BGSource.GREY:
626
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
627
+ elif bg_source == BGSource.LEFT:
628
+ gradient = np.linspace(224, 32, image_width)
629
+ image = np.tile(gradient, (image_height, 1))
630
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
631
+ elif bg_source == BGSource.RIGHT:
632
+ gradient = np.linspace(32, 224, image_width)
633
+ image = np.tile(gradient, (image_height, 1))
634
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
635
+ elif bg_source == BGSource.TOP:
636
+ gradient = np.linspace(224, 32, image_height)[:, None]
637
+ image = np.tile(gradient, (1, image_width))
638
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
639
+ elif bg_source == BGSource.BOTTOM:
640
+ gradient = np.linspace(32, 224, image_height)[:, None]
641
+ image = np.tile(gradient, (1, image_width))
642
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
643
+ else:
644
+ raise ValueError('Wrong background source!')
645
+
646
+ input_fg, matting = run_rmbg(input_fg)
647
+ results, extra_images = process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
648
+ results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
649
+ final_results = results + extra_images
650
+
651
+ # Save the generated images
652
+ save_images(results, prefix="relight")
653
+
654
+ return results
655
+
656
+
657
+ quick_prompts = [
658
+ 'sunshine from window',
659
+ 'neon light, city',
660
+ 'sunset over sea',
661
+ 'golden time',
662
+ 'sci-fi RGB glowing, cyberpunk',
663
+ 'natural lighting',
664
+ 'warm atmosphere, at home, bedroom',
665
+ 'magic lit',
666
+ 'evil, gothic, Yharnam',
667
+ 'light and shadow',
668
+ 'shadow from window',
669
+ 'soft studio lighting',
670
+ 'home atmosphere, cozy bedroom illumination',
671
+ 'neon, Wong Kar-wai, warm'
672
+ ]
673
+ quick_prompts = [[x] for x in quick_prompts]
674
+
675
+
676
+ quick_subjects = [
677
+ 'modern sofa, high quality leather',
678
+ 'elegant dining table, polished wood',
679
+ 'luxurious bed, premium mattress',
680
+ 'minimalist office desk, clean design',
681
+ 'vintage wooden cabinet, antique finish',
682
+ ]
683
+ quick_subjects = [[x] for x in quick_subjects]
684
+
685
+
686
+ class BGSource(Enum):
687
+ UPLOAD = "Use Background Image"
688
+ UPLOAD_FLIP = "Use Flipped Background Image"
689
+ LEFT = "Left Light"
690
+ RIGHT = "Right Light"
691
+ TOP = "Top Light"
692
+ BOTTOM = "Bottom Light"
693
+ GREY = "Ambient"
694
+
695
+ # Add save function
696
+ def save_images(images, prefix="relight"):
+     # Create output directory if it doesn't exist
+     output_dir = Path("outputs")
+     output_dir.mkdir(exist_ok=True)
+
+     # Create timestamp for unique filenames
+     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
+
+     saved_paths = []
+     for i, img in enumerate(images):
+         # Convert to PIL Image if numpy array
+         if isinstance(img, np.ndarray):
+             img = Image.fromarray(img)
+
+         # Create filename with timestamp
+         filename = f"{prefix}_{timestamp}_{i+1}.png"
+         filepath = output_dir / filename
+
+         # Save image and record its path
+         img.save(filepath)
+         saved_paths.append(str(filepath))
+
+     return saved_paths
720
+
721
+
722
+ class MaskMover:
723
+ def __init__(self):
724
+ self.extracted_fg = None
725
+ self.original_fg = None # Store original foreground
726
+
727
+ def set_extracted_fg(self, fg_image):
728
+ """Store the extracted foreground with alpha channel"""
729
+ if isinstance(fg_image, np.ndarray):
730
+ self.extracted_fg = fg_image.copy()
731
+ self.original_fg = fg_image.copy()
732
+ else:
733
+ self.extracted_fg = np.array(fg_image)
734
+ self.original_fg = np.array(fg_image)
735
+ return self.extracted_fg
736
+
737
+ def create_composite(self, background, x_pos, y_pos, scale=1.0):
738
+ """Create composite with foreground at specified position"""
739
+ if self.original_fg is None or background is None:
740
+ return background
741
+
742
+ # Convert inputs to PIL Images
743
+ if isinstance(background, np.ndarray):
744
+ bg = Image.fromarray(background).convert('RGBA')
745
+ else:
746
+ bg = background.convert('RGBA')
747
+
748
+ if isinstance(self.original_fg, np.ndarray):
749
+ fg = Image.fromarray(self.original_fg).convert('RGBA')
750
+ else:
751
+ fg = self.original_fg.convert('RGBA')
752
+
753
+ # Scale the foreground size
754
+ new_width = int(fg.width * scale)
755
+ new_height = int(fg.height * scale)
756
+ fg = fg.resize((new_width, new_height), Image.LANCZOS)
757
+
758
+ # Center the scaled foreground at the position
759
+ x = int(x_pos - new_width / 2)
760
+ y = int(y_pos - new_height / 2)
761
+
762
+ # Create composite
763
+ result = bg.copy()
764
+ result.paste(fg, (x, y), fg) # Use fg as the mask (requires fg to be in 'RGBA' mode)
765
+
766
+ return np.array(result.convert('RGB')) # Convert back to 'RGB' if needed
767
+
768
+ def get_depth(image):
769
+ if image is None:
770
+ return None
771
+ # Convert from PIL/gradio format to cv2
772
+ raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
773
+ # Get depth map
774
+ depth = model.infer_image(raw_img) # HxW raw depth map
775
+ # Normalize depth for visualization
776
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
777
+ # Convert to RGB for display
778
+ depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
779
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
780
+ return Image.fromarray(depth_colored)
781
+
782
+
783
+ from PIL import Image
784
+
785
+ def compress_image(image):
786
+ # Convert Gradio image (numpy array) to PIL Image
787
+ img = Image.fromarray(image)
788
+
789
+ # Resize image if dimensions are too large
790
+ max_size = 1024 # Maximum dimension size
791
+ if img.width > max_size or img.height > max_size:
792
+ ratio = min(max_size/img.width, max_size/img.height)
793
+ new_size = (int(img.width * ratio), int(img.height * ratio))
794
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
795
+
796
+ quality = 95 # Start with high quality
797
+ img.save("compressed_image.jpg", "JPEG", quality=quality) # Initial save
798
+
799
+ # Check file size and adjust quality if necessary
800
+ while os.path.getsize("compressed_image.jpg") > 100 * 1024: # 100KB limit
801
+ quality -= 5 # Decrease quality
802
+ img.save("compressed_image.jpg", "JPEG", quality=quality)
803
+ if quality < 20: # Prevent quality from going too low
804
+ break
805
+
806
+ # Convert back to numpy array for Gradio
807
+ compressed_img = np.array(Image.open("compressed_image.jpg"))
808
+ return compressed_img
809
+
810
+ @spaces.GPU(duration=60)
811
+ @torch.inference_mode()
812
+ def process_image(input_image, input_text):
813
+ """Main processing function for the Gradio interface"""
814
+
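+ # Pipeline: DINO-X (cloud API) detects boxes for the text prompt, SAM2 converts those boxes
+ # into masks locally, and the first mask is cropped out as an RGBA foreground for relighting.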
815
+ # Initialize configs
816
+ # Avoid hard-coding the DDS API key; DDS_API_TOKEN is an arbitrary env-var name used here as an optional override
+ API_TOKEN = os.environ.get("DDS_API_TOKEN", "9c8c865e10ec1821bea79d9fa9dc8720")
817
+ SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
818
+ SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
819
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
820
+ OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
821
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
822
+
823
+ # Initialize DDS client
824
+ config = Config(API_TOKEN)
825
+ client = Client(config)
826
+
827
+ # Process classes from text prompt
828
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
829
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
830
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
831
+
832
+ # Save input image to temp file and get URL
833
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
834
+ cv2.imwrite(tmpfile.name, input_image)
835
+ image_url = client.upload_file(tmpfile.name)
836
+ os.remove(tmpfile.name)
837
+
838
+ # Run DINO-X detection
839
+ task = DinoxTask(
840
+ image_url=image_url,
841
+ prompts=[TextPrompt(text=input_text)]
842
+ )
843
+ client.run_task(task)
844
+ result = task.result
845
+ objects = result.objects
846
+
847
+ # Process detection results
848
+ input_boxes = []
849
+ confidences = []
850
+ class_names = []
851
+ class_ids = []
852
+
853
+ for obj in objects:
854
+ input_boxes.append(obj.bbox)
855
+ confidences.append(obj.score)
856
+ cls_name = obj.category.lower().strip()
857
+ class_names.append(cls_name)
858
+ class_ids.append(class_name_to_id[cls_name])
859
+
860
+ input_boxes = np.array(input_boxes)
861
+ class_ids = np.array(class_ids)
862
+
863
+ # Initialize SAM2
864
+ torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
865
+ if DEVICE == "cuda" and torch.cuda.get_device_properties(0).major >= 8:
866
+ torch.backends.cuda.matmul.allow_tf32 = True
867
+ torch.backends.cudnn.allow_tf32 = True
868
+
869
+ sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
870
+ sam2_predictor = SAM2ImagePredictor(sam2_model)
871
+ sam2_predictor.set_image(input_image)
872
+
873
+ # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
874
+
875
+
876
+ # Get masks from SAM2
877
+ masks, scores, logits = sam2_predictor.predict(
878
+ point_coords=None,
879
+ point_labels=None,
880
+ box=input_boxes,
881
+ multimask_output=False,
882
+ )
883
+ if masks.ndim == 4:
884
+ masks = masks.squeeze(1)
885
+
886
+ # Create visualization
887
+ labels = [f"{class_name} {confidence:.2f}"
888
+ for class_name, confidence in zip(class_names, confidences)]
889
+
890
+ detections = sv.Detections(
891
+ xyxy=input_boxes,
892
+ mask=masks.astype(bool),
893
+ class_id=class_ids
894
+ )
895
+
896
+ box_annotator = sv.BoxAnnotator()
897
+ label_annotator = sv.LabelAnnotator()
898
+ mask_annotator = sv.MaskAnnotator()
899
+
900
+ annotated_frame = input_image.copy()
901
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
902
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
903
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
904
+
905
+ # Create transparent mask for first detected object
906
+ if len(detections) > 0:
907
+ # Get first mask
908
+ first_mask = detections.mask[0]
909
+
910
+ # Get original RGB image
911
+ img = input_image.copy()
912
+ H, W, C = img.shape
913
+
914
+ # Create RGBA image
915
+ alpha = np.zeros((H, W, 1), dtype=np.uint8)
916
+ alpha[first_mask] = 255
917
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
918
+
919
+ # Crop to mask bounds to minimize image size
920
+ y_indices, x_indices = np.where(first_mask)
921
+ y_min, y_max = y_indices.min(), y_indices.max()
922
+ x_min, x_max = x_indices.min(), x_indices.max()
923
+
924
+ # Crop the RGBA image
925
+ cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
926
+
927
+ # Set extracted foreground for mask mover
928
+ mask_mover.set_extracted_fg(cropped_rgba)
929
+
930
+ return annotated_frame, cropped_rgba, gr.update(visible=True), gr.update(visible=True)
931
+
932
+ return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
933
+
934
+
935
+ block = gr.Blocks().queue()
936
+ with block:
937
+ with gr.Tab("Text"):
938
+ with gr.Row():
939
+ gr.Markdown("## Product Placement from Text")
940
+ with gr.Row():
941
+ with gr.Column():
942
+ with gr.Row():
943
+ input_fg = gr.Image(type="numpy", label="Image", height=480)
944
+ with gr.Row():
945
+ with gr.Group():
946
+ find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
947
+ text_prompt = gr.Textbox(
948
+ label="Text Prompt",
949
+ placeholder="Enter object classes separated by periods (e.g. 'car . person .')",
950
+ value="couch . table ."
951
+ )
952
+ extract_button = gr.Button(value="(Option 2) Remove Background")
953
+ with gr.Row():
954
+ extracted_objects = gr.Image(type="numpy", label="Detected Objects", height=480)
955
+ extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
956
+
957
+ # output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
958
+ with gr.Group():
959
+ prompt = gr.Textbox(label="Prompt")
960
+ bg_source = gr.Radio(choices=[e.value for e in BGSource],
961
+ value=BGSource.GREY.value,
962
+ label="Lighting Preference (Initial Latent)", type='value')
963
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
964
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
965
+ relight_button = gr.Button(value="Relight")
966
+
967
+ with gr.Group():
968
+ with gr.Row():
969
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
970
+ seed = gr.Number(label="Seed", value=12345, precision=0)
971
+
972
+ with gr.Row():
973
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
974
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
975
+
976
+ with gr.Accordion("Advanced options", open=False):
977
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
978
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01)
979
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
980
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
981
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
982
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
983
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality')
984
+ with gr.Column():
985
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
986
+ with gr.Row():
987
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
988
+ # gr.Examples(
989
+ # fn=lambda *args: ([args[-1]], None),
990
+ # examples=db_examples.foreground_conditioned_examples,
991
+ # inputs=[
992
+ # input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
993
+ # ],
994
+ # outputs=[result_gallery, output_bg],
995
+ # run_on_click=True, examples_per_page=1024
996
+ # )
997
+ ips = [input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
998
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[extracted_fg, result_gallery])
999
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
1000
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
1001
+ find_objects_button.click(
1002
+ fn=lambda img, txt: process_image(img, txt)[:2],  # process_image returns four values; this tab only wires the two image outputs
1003
+ inputs=[input_fg, text_prompt],
1004
+ outputs=[extracted_objects, extracted_fg]
1005
+ )
1006
+
1007
+ with gr.Tab("Background", visible=False):
1008
+ # empty cache
1009
+
1010
+ mask_mover = MaskMover()
1011
+
1012
+ # with torch.no_grad():
1013
+ # # Update the input channels to 12
1014
+ # new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) # Changed from 8 to 12
1015
+ # new_conv_in.weight.zero_()
1016
+ # new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
1017
+ # new_conv_in.bias = unet.conv_in.bias
1018
+ # unet.conv_in = new_conv_in
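+ # Note: the two-image (foreground + background) conditioning used by process_bg needs the
+ # 12-channel conv_in sketched in the commented block above; with the 8-channel conv_in set up
+ # for the "fc" model, the channel concatenation would not match, which is presumably why this
+ # tab ships hidden (visible=False).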
1019
+ with gr.Row():
1020
+ gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
1021
+ gr.Markdown("💾 Generated images are automatically saved to 'outputs' folder")
1022
+
1023
+ with gr.Row():
1024
+ with gr.Column():
1025
+ # Step 1: Input and Extract
1026
+ with gr.Row():
1027
+ with gr.Group():
1028
+ gr.Markdown("### Step 1: Extract Foreground")
1029
+ input_image = gr.Image(type="numpy", label="Input Image", height=480)
1030
+ # find_objects_button = gr.Button(value="Find Objects")
1031
+ extract_button = gr.Button(value="Remove Background")
1032
+ extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
1033
+
1034
+ with gr.Row():
1035
+ # Step 2: Background and Position
1036
+ with gr.Group():
1037
+ gr.Markdown("### Step 2: Position on Background")
1038
+ input_bg = gr.Image(type="numpy", label="Background Image", height=480)
1039
+
1040
+ with gr.Row():
1041
+ x_slider = gr.Slider(
1042
+ minimum=0,
1043
+ maximum=1000,
1044
+ label="X Position",
1045
+ value=500,
1046
+ visible=False
1047
+ )
1048
+ y_slider = gr.Slider(
1049
+ minimum=0,
1050
+ maximum=1000,
1051
+ label="Y Position",
1052
+ value=500,
1053
+ visible=False
1054
+ )
1055
+ fg_scale_slider = gr.Slider(
1056
+ label="Foreground Scale",
1057
+ minimum=0.01,
1058
+ maximum=3.0,
1059
+ value=1.0,
1060
+ step=0.01
1061
+ )
1062
+
1063
+ editor = gr.ImageEditor(
1064
+ type="numpy",
1065
+ label="Position Foreground",
1066
+ height=480,
1067
+ visible=False
1068
+ )
1069
+ get_depth_button = gr.Button(value="Get Depth")
1070
+ depth_image = gr.Image(type="numpy", label="Depth Image", height=480)
1071
+
1072
+ # Step 3: Relighting Options
1073
+ with gr.Group():
1074
+ gr.Markdown("### Step 3: Relighting Settings")
1075
+ prompt = gr.Textbox(label="Prompt")
1076
+ bg_source = gr.Radio(
1077
+ choices=[e.value for e in BGSource],
1078
+ value=BGSource.UPLOAD.value,
1079
+ label="Background Source",
1080
+ type='value'
1081
+ )
1082
+
1083
+ example_prompts = gr.Dataset(
1084
+ samples=quick_prompts,
1085
+ label='Prompt Quick List',
1086
+ components=[prompt]
1087
+ )
1088
+ # bg_gallery = gr.Gallery(
1089
+ # height=450,
1090
+ # label='Background Quick List',
1091
+ # value=db_examples.bg_samples,
1092
+ # columns=5,
1093
+ # allow_preview=False
1094
+ # )
1095
+ relight_button_bg = gr.Button(value="Relight")
1096
+
1097
+ # Additional settings
1098
+ with gr.Group():
1099
+ with gr.Row():
1100
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1101
+ seed = gr.Number(label="Seed", value=12345, precision=0)
1102
+ with gr.Row():
1103
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
1104
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
1105
+
1106
+ with gr.Accordion("Advanced options", open=False):
1107
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1108
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
1109
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
1110
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
1111
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
1112
+ n_prompt = gr.Textbox(
1113
+ label="Negative Prompt",
1114
+ value='lowres, bad anatomy, bad hands, cropped, worst quality'
1115
+ )
1116
+
1117
+ with gr.Column():
1118
+ result_gallery = gr.Image(height=832, label='Outputs')
1119
+
1120
+ def extract_foreground(image):
1121
+ if image is None:
1122
+ return None, gr.update(visible=True), gr.update(visible=True)
1123
+ result, rgba = run_rmbg(image)
1124
+ mask_mover.set_extracted_fg(rgba)
1125
+
1126
+ return result, gr.update(visible=True), gr.update(visible=True)
1127
+
1128
+
1129
+ original_bg = None
1130
+
1131
+ extract_button.click(
1132
+ fn=extract_foreground,
1133
+ inputs=[input_image],
1134
+ outputs=[extracted_fg, x_slider, y_slider]
1135
+ )
1136
+
1137
+ find_objects_button.click(
1138
+ fn=process_image,
1139
+ inputs=[input_image, text_prompt],
1140
+ outputs=[extracted_objects, extracted_fg, x_slider, y_slider]
1141
+ )
1142
+
1143
+ get_depth_button.click(
1144
+ fn=get_depth,
1145
+ inputs=[input_bg],
1146
+ outputs=[depth_image]
1147
+ )
1148
+
1149
+ # def update_position(background, x_pos, y_pos, scale):
1150
+ # """Update composite when position changes"""
1151
+ # global original_bg
1152
+ # if background is None:
1153
+ # return None
1154
+
1155
+ # if original_bg is None:
1156
+ # original_bg = background.copy()
1157
+
1158
+ # # Convert string values to float
1159
+ # x_pos = float(x_pos)
1160
+ # y_pos = float(y_pos)
1161
+ # scale = float(scale)
1162
+
1163
+ # return mask_mover.create_composite(original_bg, x_pos, y_pos, scale)
1164
+
1165
+ class BackgroundManager:
1166
+ def __init__(self):
1167
+ self.original_bg = None
1168
+
1169
+ def update_position(self, background, x_pos, y_pos, scale):
1170
+ """Update composite when position changes"""
1171
+ if background is None:
1172
+ return None
1173
+
1174
+ if self.original_bg is None:
1175
+ self.original_bg = background.copy()
1176
+
1177
+ # Convert string values to float
1178
+ x_pos = float(x_pos)
1179
+ y_pos = float(y_pos)
1180
+ scale = float(scale)
1181
+
1182
+ return mask_mover.create_composite(self.original_bg, x_pos, y_pos, scale)
1183
+
1184
+ # Create an instance of BackgroundManager
1185
+ bg_manager = BackgroundManager()
1186
+
1187
+
1188
+ x_slider.change(
1189
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1190
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1191
+ outputs=[input_bg]
1192
+ )
1193
+
1194
+ y_slider.change(
1195
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1196
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1197
+ outputs=[input_bg]
1198
+ )
1199
+
1200
+ fg_scale_slider.change(
1201
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1202
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1203
+ outputs=[input_bg]
1204
+ )
1205
+
1206
+ # Update inputs list to include fg_scale_slider
1207
+
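+ # process_relight_with_position crops a padded window around the placed foreground, relights
+ # only that window via process_relight_bg, and pastes the relit crop back into the
+ # full-resolution background.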
1208
+ def process_relight_with_position(*args):
1209
+ if mask_mover.extracted_fg is None:
1210
+ gr.Warning("Please extract foreground first")
1211
+ return None
1212
+
1213
+ background = args[1] # Get background image
1214
+ x_pos = float(args[-3]) # x_slider value
1215
+ y_pos = float(args[-2]) # y_slider value
1216
+ scale = float(args[-1]) # fg_scale_slider value
1217
+
1218
+ # Get original foreground size after scaling
1219
+ fg = Image.fromarray(mask_mover.original_fg)
1220
+ new_width = int(fg.width * scale)
1221
+ new_height = int(fg.height * scale)
1222
+
1223
+ # Calculate crop region around foreground position
1224
+ crop_x = int(x_pos - new_width/2)
1225
+ crop_y = int(y_pos - new_height/2)
1226
+ crop_width = new_width
1227
+ crop_height = new_height
1228
+
1229
+ # Add padding for context (20% extra on each side)
1230
+ padding = 0.2
1231
+ crop_x = int(crop_x - crop_width * padding)
1232
+ crop_y = int(crop_y - crop_height * padding)
1233
+ crop_width = int(crop_width * (1 + 2 * padding))
1234
+ crop_height = int(crop_height * (1 + 2 * padding))
1235
+
1236
+ # Ensure crop dimensions are multiples of 8
1237
+ crop_width = ((crop_width + 7) // 8) * 8
1238
+ crop_height = ((crop_height + 7) // 8) * 8
1239
+
1240
+ # Ensure crop region is within image bounds
1241
+ bg_height, bg_width = background.shape[:2]
1242
+ crop_x = max(0, min(crop_x, bg_width - crop_width))
1243
+ crop_y = max(0, min(crop_y, bg_height - crop_height))
1244
+
1245
+ # Get actual crop dimensions after boundary check
1246
+ crop_width = min(crop_width, bg_width - crop_x)
1247
+ crop_height = min(crop_height, bg_height - crop_y)
1248
+
1249
+ # Ensure dimensions are multiples of 8 again
1250
+ crop_width = (crop_width // 8) * 8
1251
+ crop_height = (crop_height // 8) * 8
1252
+
1253
+ # Crop region from background
1254
+ crop_region = background[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width]
1255
+
1256
+ # Create composite in cropped region
1257
+ fg_local_x = int(new_width/2 + crop_width*padding)
1258
+ fg_local_y = int(new_height/2 + crop_height*padding)
1259
+ cropped_composite = mask_mover.create_composite(crop_region, fg_local_x, fg_local_y, scale)
1260
+
1261
+ # Process the cropped region
1262
+ crop_args = list(args)
1263
+ crop_args[0] = cropped_composite
1264
+ crop_args[1] = crop_region
1265
+ crop_args[3] = crop_width
1266
+ crop_args[4] = crop_height
1267
+ crop_args = crop_args[:-3] # Remove position and scale arguments
1268
+
1269
+ # Get relit result
1270
+ relit_crop = process_relight_bg(*crop_args)[0]
1271
+
1272
+ # Resize relit result to match crop dimensions if needed
1273
+ if relit_crop.shape[:2] != (crop_height, crop_width):
1274
+ relit_crop = resize_without_crop(relit_crop, crop_width, crop_height)
1275
+
1276
+ # Place relit crop back into original background
1277
+ result = background.copy()
1278
+ result[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width] = relit_crop
1279
+
1280
+ return result
1281
+
1282
+ ips_bg = [input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source, x_slider, y_slider, fg_scale_slider]  # position/scale sliders are consumed (and stripped) by process_relight_with_position
1283
+
1284
+ # Update button click events with new inputs list
1285
+ relight_button_bg.click(
1286
+ fn=process_relight_with_position,
1287
+ inputs=ips_bg,
1288
+ outputs=[result_gallery]
1289
+ )
1290
+
1291
+
1292
+ example_prompts.click(
1293
+ fn=lambda x: x[0],
1294
+ inputs=example_prompts,
1295
+ outputs=prompt,
1296
+ show_progress=False,
1297
+ queue=False
1298
+ )
1299
+
1300
+
1301
+ block.launch(server_name='0.0.0.0', share=False)