Ashoka74 committed on
Commit
4027eb1
1 Parent(s): b5b5693

Create app_2.py

Files changed (1)
  1. app_2.py +1543 -0
app_2.py ADDED
@@ -0,0 +1,1543 @@
1
+ import spaces
2
+ import argparse
3
+ import random
4
+
5
+ import os
6
+ import math
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ import safetensors.torch as sf
11
+ import datetime
12
+ from pathlib import Path
13
+ from io import BytesIO
14
+
15
+ from PIL import Image
16
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
17
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
18
+ from diffusers.models.attention_processor import AttnProcessor2_0
19
+ from transformers import CLIPTextModel, CLIPTokenizer
20
+ import dds_cloudapi_sdk
21
+ from dds_cloudapi_sdk import Config, Client, TextPrompt
22
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
23
+ from dds_cloudapi_sdk.tasks import DetectionTarget
24
+ from dds_cloudapi_sdk.tasks.detection import DetectionTask
25
+
26
+ from enum import Enum
27
+ from torch.hub import download_url_to_file
28
+ import tempfile
29
+
30
+ from sam2.build_sam import build_sam2
31
+
32
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
33
+ import cv2
34
+
35
+ from transformers import AutoModelForImageSegmentation
36
+ from inference_i2mv_sdxl import prepare_pipeline, remove_bg, run_pipeline
37
+
38
+
39
+ from typing import Optional
40
+
41
+ from depth_anything_v2.dpt import DepthAnythingV2
42
+
43
+ import httpx
44
+
45
+ client = httpx.Client(timeout=httpx.Timeout(10.0)) # Set timeout to 10 seconds
46
+ NUM_VIEWS = 6
47
+ HEIGHT = 768
48
+ WIDTH = 768
49
+ MAX_SEED = np.iinfo(np.int32).max
50
+
51
+
52
+ import supervision as sv
53
+ import torch
54
+ from PIL import Image
55
+
56
+
57
+
58
+
59
+ # Load
60
+
61
+ # Model paths
62
+ model_path = './models/iclight_sd15_fc.safetensors'
63
+ model_path2 = './checkpoints/depth_anything_v2_vits.pth'
64
+ model_path3 = './checkpoints/sam2_hiera_large.pt'
65
+ model_path4 = './checkpoints/config.json'
66
+ model_path5 = './checkpoints/preprocessor_config.json'
67
+ model_path6 = './configs/sam2_hiera_l.yaml'
68
+ model_path7 = './mvadapter_i2mv_sdxl.safetensors'
69
+
70
+ # Base URL for the repository
71
+ BASE_URL = 'https://huggingface.co/Ashoka74/Placement/resolve/main/'
72
+
73
+ # Model URLs
74
+ model_urls = {
75
+ model_path: 'iclight_sd15_fc.safetensors',
76
+ model_path2: 'depth_anything_v2_vits.pth',
77
+ model_path3: 'sam2_hiera_large.pt',
78
+ model_path4: 'config.json',
79
+ model_path5: 'preprocessor_config.json',
80
+ model_path6: 'sam2_hiera_l.yaml',
81
+ model_path7: 'mvadapter_i2mv_sdxl.safetensors'
82
+ }
83
+
84
+ # Ensure directories exist
85
+ def ensure_directories():
86
+ for path in model_urls.keys():
87
+ os.makedirs(os.path.dirname(path), exist_ok=True)
88
+
89
+ # Download models
90
+ def download_models():
91
+ for local_path, filename in model_urls.items():
92
+ if not os.path.exists(local_path):
93
+ try:
94
+ url = f"{BASE_URL}{filename}"
95
+ print(f"Downloading {filename}")
96
+ download_url_to_file(url, local_path)
97
+ print(f"Successfully downloaded {filename}")
98
+ except Exception as e:
99
+ print(f"Error downloading {filename}: {e}")
100
+
101
+ ensure_directories()
102
+
103
+ download_models()
104
+
105
+ # device and dtype are used here but were only defined further down in the file;
+ # define them before building the MV-Adapter pipeline to avoid a NameError
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+
+ pipe = prepare_pipeline(
106
+ base_model="stabilityai/stable-diffusion-xl-base-1.0",
107
+ vae_model="madebyollin/sdxl-vae-fp16-fix",
108
+ unet_model=None,
109
+ lora_model=None,
110
+ adapter_path="huanngzh/mv-adapter",
111
+ scheduler=None,
112
+ num_views=NUM_VIEWS,
113
+ device=device,
114
+ dtype=dtype,
115
+ )
116
+
117
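+ # MV-Adapter inference: given one reference image (and a prompt), the SDXL pipeline
+ # above generates NUM_VIEWS consistent views of the object at HEIGHT x WIDTH.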
+ @spaces.GPU()
118
+ def infer(
119
+ prompt,
120
+ image,
121
+ do_rembg=False,
122
+ seed=42,
123
+ randomize_seed=False,
124
+ guidance_scale=3.0,
125
+ num_inference_steps=50,
126
+ reference_conditioning_scale=1.0,
127
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
128
+ progress=gr.Progress(track_tqdm=True),
129
+ ):
130
+ # if do_rembg:
131
+ # remove_bg_fn = lambda x: remove_bg(x, birefnet, transform_image, device)
132
+ # else:
133
+ # remove_bg_fn = None
134
+ if randomize_seed:
135
+ seed = random.randint(0, MAX_SEED)
136
+ images, preprocessed_image = run_pipeline(
137
+ pipe,
138
+ num_views=NUM_VIEWS,
139
+ text=prompt,
140
+ image=image,
141
+ height=HEIGHT,
142
+ width=WIDTH,
143
+ num_inference_steps=num_inference_steps,
144
+ guidance_scale=guidance_scale,
145
+ seed=seed,
146
+ remove_bg_fn=None,
147
+ reference_conditioning_scale=reference_conditioning_scale,
148
+ negative_prompt=negative_prompt,
149
+ device=device,
150
+ )
151
+ return images
152
+
153
+
154
+ try:
155
+ import xformers
156
+ import xformers.ops
157
+ XFORMERS_AVAILABLE = True
158
+ print("xformers is available - Using memory efficient attention")
159
+ except ImportError:
160
+ XFORMERS_AVAILABLE = False
161
+ print("xformers not available - Using default attention")
162
+
163
+ # Memory optimizations for RTX 2070
164
+ torch.backends.cudnn.benchmark = True
165
+ if torch.cuda.is_available():
166
+ torch.backends.cuda.matmul.allow_tf32 = True
167
+ torch.backends.cudnn.allow_tf32 = True
168
+ # Set a smaller attention slice size for RTX 2070
169
+ torch.backends.cuda.max_split_size_mb = 512
170
+ device = torch.device('cuda')
171
+ else:
172
+ device = torch.device('cpu')
173
+
174
+ # 'stablediffusionapi/realistic-vision-v51'
175
+ # 'runwayml/stable-diffusion-v1-5'
176
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
177
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
178
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
179
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
180
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
181
+ # Load model directly
182
+ from transformers import AutoModelForImageSegmentation
183
+ rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
184
+ rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
185
+
186
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
187
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
188
+ model = model.to(device)
189
+ model.eval()
190
+
191
+ # Change UNet
192
+
193
+
194
+ with torch.no_grad():
195
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
196
+ new_conv_in.weight.zero_()
197
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
198
+ new_conv_in.bias = unet.conv_in.bias
199
+ unet.conv_in = new_conv_in
200
+
201
+
202
+ unet_original_forward = unet.forward
203
+
204
+
205
+ def enable_efficient_attention():
206
+ if XFORMERS_AVAILABLE:
207
+ try:
208
+ # RTX 2070 specific settings
209
+ unet.set_use_memory_efficient_attention_xformers(True)
210
+ vae.set_use_memory_efficient_attention_xformers(True)
211
+ print("Enabled xformers memory efficient attention")
212
+ except Exception as e:
213
+ print(f"Xformers error: {e}")
214
+ print("Falling back to sliced attention")
215
+ # Use sliced attention for RTX 2070
216
+ # unet.set_attention_slice_size(4)
217
+ # vae.set_attention_slice_size(4)
218
+ unet.set_attn_processor(AttnProcessor2_0())
219
+ vae.set_attn_processor(AttnProcessor2_0())
220
+ else:
221
+ # Fallback for when xformers is not available
222
+ print("Using sliced attention")
223
+ # unet.set_attention_slice_size(4)
224
+ # vae.set_attention_slice_size(4)
225
+ unet.set_attn_processor(AttnProcessor2_0())
226
+ vae.set_attn_processor(AttnProcessor2_0())
227
+
228
+ # Add memory clearing function
229
+ def clear_memory():
230
+ if torch.cuda.is_available():
231
+ torch.cuda.empty_cache()
232
+ torch.cuda.synchronize()
233
+
234
+ # Enable efficient attention
235
+ enable_efficient_attention()
236
+
237
+
238
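+ # The hooked forward pulls the foreground latents out of cross_attention_kwargs, repeats
+ # them to match the batch size, and concatenates them to the noisy sample along the
+ # channel axis to feed the widened conv_in above.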
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
239
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
240
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
241
+ new_sample = torch.cat([sample, c_concat], dim=1)
242
+ kwargs['cross_attention_kwargs'] = {}
243
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
244
+
245
+
246
+ unet.forward = hooked_unet_forward
247
+
248
+
249
+
250
+
251
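+ # The IC-Light checkpoint is distributed as an additive offset over the base SD1.5 UNet
+ # weights, so the two state dicts are summed key by key before loading.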
+ sd_offset = sf.load_file(model_path)
252
+ sd_origin = unet.state_dict()
253
+ keys = sd_origin.keys()
254
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
255
+ unet.load_state_dict(sd_merged, strict=True)
256
+ del sd_offset, sd_origin, sd_merged, keys
257
+
258
+ # Device
259
+
260
+ # device = torch.device('cuda')
261
+ # text_encoder = text_encoder.to(device=device, dtype=torch.float16)
262
+ # vae = vae.to(device=device, dtype=torch.bfloat16)
263
+ # unet = unet.to(device=device, dtype=torch.float16)
264
+ # rmbg = rmbg.to(device=device, dtype=torch.float32)
265
+
266
+
267
+ # Device and dtype setup
268
+ device = torch.device('cuda')
269
+ dtype = torch.float16 # RTX 2070 works well with float16
270
+
271
+ # Memory optimizations for RTX 2070
272
+ torch.backends.cudnn.benchmark = True
273
+ if torch.cuda.is_available():
274
+ torch.backends.cuda.matmul.allow_tf32 = True
275
+ torch.backends.cudnn.allow_tf32 = True
276
+ # Set a very small attention slice size for RTX 2070 to avoid OOM
277
+ torch.backends.cuda.max_split_size_mb = 128
278
+
279
+ # Move models to device with consistent dtype
280
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
281
+ vae = vae.to(device=device, dtype=dtype) # Changed from bfloat16 to float16
282
+ unet = unet.to(device=device, dtype=dtype)
283
+ rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
284
+
285
+
286
+ ddim_scheduler = DDIMScheduler(
287
+ num_train_timesteps=1000,
288
+ beta_start=0.00085,
289
+ beta_end=0.012,
290
+ beta_schedule="scaled_linear",
291
+ clip_sample=False,
292
+ set_alpha_to_one=False,
293
+ steps_offset=1,
294
+ )
295
+
296
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
297
+ num_train_timesteps=1000,
298
+ beta_start=0.00085,
299
+ beta_end=0.012,
300
+ steps_offset=1
301
+ )
302
+
303
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
304
+ num_train_timesteps=1000,
305
+ beta_start=0.00085,
306
+ beta_end=0.012,
307
+ algorithm_type="sde-dpmsolver++",
308
+ use_karras_sigmas=True,
309
+ steps_offset=1
310
+ )
311
+
312
+ # Pipelines
313
+
314
+ t2i_pipe = StableDiffusionPipeline(
315
+ vae=vae,
316
+ text_encoder=text_encoder,
317
+ tokenizer=tokenizer,
318
+ unet=unet,
319
+ scheduler=dpmpp_2m_sde_karras_scheduler,
320
+ safety_checker=None,
321
+ requires_safety_checker=False,
322
+ feature_extractor=None,
323
+ image_encoder=None
324
+ )
325
+
326
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
327
+ vae=vae,
328
+ text_encoder=text_encoder,
329
+ tokenizer=tokenizer,
330
+ unet=unet,
331
+ scheduler=dpmpp_2m_sde_karras_scheduler,
332
+ safety_checker=None,
333
+ requires_safety_checker=False,
334
+ feature_extractor=None,
335
+ image_encoder=None
336
+ )
337
+
338
+
339
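+ # CLIP only accepts 77 tokens at a time; long prompts are split into 75-token chunks,
+ # each wrapped in BOS/EOS and padded, then encoded separately and stacked.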
+ @torch.inference_mode()
340
+ def encode_prompt_inner(txt: str):
341
+ max_length = tokenizer.model_max_length
342
+ chunk_length = tokenizer.model_max_length - 2
343
+ id_start = tokenizer.bos_token_id
344
+ id_end = tokenizer.eos_token_id
345
+ id_pad = id_end
346
+
347
+ def pad(x, p, i):
348
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
349
+
350
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
351
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
352
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
353
+
354
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
355
+ conds = text_encoder(token_ids).last_hidden_state
356
+
357
+ return conds
358
+
359
+
360
+ @torch.inference_mode()
361
+ def encode_prompt_pair(positive_prompt, negative_prompt):
362
+ c = encode_prompt_inner(positive_prompt)
363
+ uc = encode_prompt_inner(negative_prompt)
364
+
365
+ c_len = float(len(c))
366
+ uc_len = float(len(uc))
367
+ max_count = max(c_len, uc_len)
368
+ c_repeat = int(math.ceil(max_count / c_len))
369
+ uc_repeat = int(math.ceil(max_count / uc_len))
370
+ max_chunk = max(len(c), len(uc))
371
+
372
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
373
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
374
+
375
+ c = torch.cat([p[None, ...] for p in c], dim=1)
376
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
377
+
378
+ return c, uc
379
+
380
+ @spaces.GPU(duration=60)
381
+ @torch.inference_mode()
382
+ def pytorch2numpy(imgs, quant=True):
383
+ results = []
384
+ for x in imgs:
385
+ y = x.movedim(0, -1)
386
+
387
+ if quant:
388
+ y = y * 127.5 + 127.5
389
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
390
+ else:
391
+ y = y * 0.5 + 0.5
392
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
393
+
394
+ results.append(y)
395
+ return results
396
+
397
+ @spaces.GPU(duration=60)
398
+ @torch.inference_mode()
399
+ def numpy2pytorch(imgs):
400
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
401
+ h = h.movedim(-1, 1)
402
+ return h
403
+
404
+
405
+ def resize_and_center_crop(image, target_width, target_height):
406
+ pil_image = Image.fromarray(image)
407
+ original_width, original_height = pil_image.size
408
+ scale_factor = max(target_width / original_width, target_height / original_height)
409
+ resized_width = int(round(original_width * scale_factor))
410
+ resized_height = int(round(original_height * scale_factor))
411
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
412
+ left = (resized_width - target_width) / 2
413
+ top = (resized_height - target_height) / 2
414
+ right = (resized_width + target_width) / 2
415
+ bottom = (resized_height + target_height) / 2
416
+ cropped_image = resized_image.crop((left, top, right, bottom))
417
+ return np.array(cropped_image)
418
+
419
+
420
+ def resize_without_crop(image, target_width, target_height):
421
+ pil_image = Image.fromarray(image)
422
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
423
+ return np.array(resized_image)
424
+
425
+ @spaces.GPU(duration=60)
426
+ @torch.inference_mode()
427
+ def run_rmbg(img, sigma=0.0):
428
+ # Convert RGBA to RGB if needed
429
+ if img.shape[-1] == 4:
430
+ # Use white background for alpha composition
431
+ alpha = img[..., 3:] / 255.0
432
+ rgb = img[..., :3]
433
+ white_bg = np.ones_like(rgb) * 255
434
+ img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
435
+
436
+ H, W, C = img.shape
437
+ assert C == 3
438
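+ # k scales H and W so the RMBG input (64*round(W*k) x 64*round(H*k)) is roughly
+ # one megapixel while preserving the aspect ratio.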
+ k = (256.0 / float(H * W)) ** 0.5
439
+ feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
440
+ feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
441
+ alpha = rmbg(feed)[0][0]
442
+ alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
443
+ alpha = alpha.movedim(1, -1)[0]
444
+ alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
445
+
446
+ # Create RGBA image
447
+ rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
448
+ result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
449
+ return result.clip(0, 255).astype(np.uint8), rgba
450
+
451
+ @spaces.GPU(duration=60)
452
+ @torch.inference_mode()
453
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
454
+ clear_memory()
+ input_bg = None  # no background upload in this tab; the branches below may synthesise one
455
+
456
+ # Get input dimensions
457
+ input_height, input_width = input_fg.shape[:2]
458
+
459
+ bg_source = BGSource(bg_source)
460
+
461
+
462
+ if bg_source == BGSource.UPLOAD:
463
+ pass
464
+ elif bg_source == BGSource.UPLOAD_FLIP:
465
+ input_bg = np.fliplr(input_bg)
466
+ elif bg_source == BGSource.GREY:
467
+ input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
468
+ elif bg_source == BGSource.LEFT:
469
+ gradient = np.linspace(255, 0, input_width)
470
+ image = np.tile(gradient, (input_height, 1))
471
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
472
+ elif bg_source == BGSource.RIGHT:
473
+ gradient = np.linspace(0, 255, input_width)
474
+ image = np.tile(gradient, (input_height, 1))
475
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
476
+ elif bg_source == BGSource.TOP:
477
+ gradient = np.linspace(255, 0, input_height)[:, None]
478
+ image = np.tile(gradient, (1, input_width))
479
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
480
+ elif bg_source == BGSource.BOTTOM:
481
+ gradient = np.linspace(0, 255, input_height)[:, None]
482
+ image = np.tile(gradient, (1, input_width))
483
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
484
+ else:
485
+ raise ValueError('Wrong initial latent!')
486
+
487
+ rng = torch.Generator(device=device).manual_seed(int(seed))
488
+
489
+ # Use input dimensions directly
490
+ fg = resize_without_crop(input_fg, input_width, input_height)
491
+
492
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
493
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
494
+
495
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
496
+
497
+ if input_bg is None:
498
+ latents = t2i_pipe(
499
+ prompt_embeds=conds,
500
+ negative_prompt_embeds=unconds,
501
+ width=input_width,
502
+ height=input_height,
503
+ num_inference_steps=steps,
504
+ num_images_per_prompt=num_samples,
505
+ generator=rng,
506
+ output_type='latent',
507
+ guidance_scale=cfg,
508
+ cross_attention_kwargs={'concat_conds': concat_conds},
509
+ ).images.to(vae.dtype) / vae.config.scaling_factor
510
+ else:
511
+ bg = resize_without_crop(input_bg, input_width, input_height)
512
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
513
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
514
+ latents = i2i_pipe(
515
+ image=bg_latent,
516
+ strength=lowres_denoise,
517
+ prompt_embeds=conds,
518
+ negative_prompt_embeds=unconds,
519
+ width=input_width,
520
+ height=input_height,
521
+ num_inference_steps=int(round(steps / lowres_denoise)),
522
+ num_images_per_prompt=num_samples,
523
+ generator=rng,
524
+ output_type='latent',
525
+ guidance_scale=cfg,
526
+ cross_attention_kwargs={'concat_conds': concat_conds},
527
+ ).images.to(vae.dtype) / vae.config.scaling_factor
528
+
529
+ pixels = vae.decode(latents).sample
530
+ pixels = pytorch2numpy(pixels)
531
+ pixels = [resize_without_crop(
532
+ image=p,
533
+ target_width=int(round(input_width * highres_scale / 64.0) * 64),
534
+ target_height=int(round(input_height * highres_scale / 64.0) * 64))
535
+ for p in pixels]
536
+
537
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
538
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
539
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
540
+
541
+ highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
542
+
543
+ fg = resize_without_crop(input_fg, highres_width, highres_height)
544
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
545
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
546
+
547
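+ # Second pass: the low-res result is upscaled, re-encoded, and refined with img2img
+ # at highres_denoise strength to add detail at the larger resolution.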
+ latents = i2i_pipe(
548
+ image=latents,
549
+ strength=highres_denoise,
550
+ prompt_embeds=conds,
551
+ negative_prompt_embeds=unconds,
552
+ width=highres_width,
553
+ height=highres_height,
554
+ num_inference_steps=int(round(steps / highres_denoise)),
555
+ num_images_per_prompt=num_samples,
556
+ generator=rng,
557
+ output_type='latent',
558
+ guidance_scale=cfg,
559
+ cross_attention_kwargs={'concat_conds': concat_conds},
560
+ ).images.to(vae.dtype) / vae.config.scaling_factor
561
+
562
+ pixels = vae.decode(latents).sample
563
+ pixels = pytorch2numpy(pixels)
564
+
565
+ # Resize back to input dimensions
566
+ pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
567
+ pixels = np.stack(pixels)
568
+
569
+ return pixels
570
+
571
+ def extract_foreground(image):
572
+ if image is None:
573
+ return None, gr.update(visible=True), gr.update(visible=True)
574
+ result, rgba = run_rmbg(image)
575
+ mask_mover.set_extracted_fg(rgba)
576
+
577
+ return result, gr.update(visible=True), gr.update(visible=True)
578
+
579
+ @torch.inference_mode()
580
+ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
581
+ clear_memory()
582
+ bg_source = BGSource(bg_source)
583
+
584
+ if bg_source == BGSource.UPLOAD:
585
+ pass
586
+ elif bg_source == BGSource.UPLOAD_FLIP:
587
+ input_bg = np.fliplr(input_bg)
588
+ elif bg_source == BGSource.GREY:
589
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
590
+ elif bg_source == BGSource.LEFT:
591
+ gradient = np.linspace(224, 32, image_width)
592
+ image = np.tile(gradient, (image_height, 1))
593
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
594
+ elif bg_source == BGSource.RIGHT:
595
+ gradient = np.linspace(32, 224, image_width)
596
+ image = np.tile(gradient, (image_height, 1))
597
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
598
+ elif bg_source == BGSource.TOP:
599
+ gradient = np.linspace(224, 32, image_height)[:, None]
600
+ image = np.tile(gradient, (1, image_width))
601
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
602
+ elif bg_source == BGSource.BOTTOM:
603
+ gradient = np.linspace(32, 224, image_height)[:, None]
604
+ image = np.tile(gradient, (1, image_width))
605
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
606
+ else:
607
+ raise ValueError('Wrong background source!')
608
+
609
+ rng = torch.Generator(device=device).manual_seed(seed)
610
+
611
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
612
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
613
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
614
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
615
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
616
+
617
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
618
+
619
+ latents = t2i_pipe(
620
+ prompt_embeds=conds,
621
+ negative_prompt_embeds=unconds,
622
+ width=image_width,
623
+ height=image_height,
624
+ num_inference_steps=steps,
625
+ num_images_per_prompt=num_samples,
626
+ generator=rng,
627
+ output_type='latent',
628
+ guidance_scale=cfg,
629
+ cross_attention_kwargs={'concat_conds': concat_conds},
630
+ ).images.to(vae.dtype) / vae.config.scaling_factor
631
+
632
+ pixels = vae.decode(latents).sample
633
+ pixels = pytorch2numpy(pixels)
634
+ pixels = [resize_without_crop(
635
+ image=p,
636
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
637
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
638
+ for p in pixels]
639
+
640
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
641
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
642
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
643
+
644
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
645
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
646
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
647
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
648
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
649
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
650
+
651
+ latents = i2i_pipe(
652
+ image=latents,
653
+ strength=highres_denoise,
654
+ prompt_embeds=conds,
655
+ negative_prompt_embeds=unconds,
656
+ width=image_width,
657
+ height=image_height,
658
+ num_inference_steps=int(round(steps / highres_denoise)),
659
+ num_images_per_prompt=num_samples,
660
+ generator=rng,
661
+ output_type='latent',
662
+ guidance_scale=cfg,
663
+ cross_attention_kwargs={'concat_conds': concat_conds},
664
+ ).images.to(vae.dtype) / vae.config.scaling_factor
665
+
666
+ pixels = vae.decode(latents).sample
667
+ pixels = pytorch2numpy(pixels, quant=False)
668
+
669
+ clear_memory()
670
+ return pixels, [fg, bg]
671
+
672
+
673
+ @torch.inference_mode()
674
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
675
+ #input_fg, matting = run_rmbg(input_fg)
676
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
677
+ return results
678
+
679
+
680
+
681
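+ # Background-conditioned relighting: the foreground is cut out with RMBG, then both the
+ # cutout and the background image are VAE-encoded and stacked as conditioning latents
+ # for the IC-Light UNet (see process_bg).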
+ @torch.inference_mode()
682
+ def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
683
+ bg_source = BGSource(bg_source)
684
+
685
+ # bg_source = "Use Background Image"
686
+
687
+ # Convert numerical inputs to appropriate types
688
+ image_width = int(image_width)
689
+ image_height = int(image_height)
690
+ num_samples = int(num_samples)
691
+ seed = int(seed)
692
+ steps = int(steps)
693
+ cfg = float(cfg)
694
+ highres_scale = float(highres_scale)
695
+ highres_denoise = float(highres_denoise)
696
+
697
+ if bg_source == BGSource.UPLOAD:
698
+ pass
699
+ elif bg_source == BGSource.UPLOAD_FLIP:
700
+ input_bg = np.fliplr(input_bg)
701
+ elif bg_source == BGSource.GREY:
702
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
703
+ elif bg_source == BGSource.LEFT:
704
+ gradient = np.linspace(224, 32, image_width)
705
+ image = np.tile(gradient, (image_height, 1))
706
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
707
+ elif bg_source == BGSource.RIGHT:
708
+ gradient = np.linspace(32, 224, image_width)
709
+ image = np.tile(gradient, (image_height, 1))
710
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
711
+ elif bg_source == BGSource.TOP:
712
+ gradient = np.linspace(224, 32, image_height)[:, None]
713
+ image = np.tile(gradient, (1, image_width))
714
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
715
+ elif bg_source == BGSource.BOTTOM:
716
+ gradient = np.linspace(32, 224, image_height)[:, None]
717
+ image = np.tile(gradient, (1, image_width))
718
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
719
+ else:
720
+ raise ValueError('Wrong background source!')
721
+
722
+ input_fg, matting = run_rmbg(input_fg)
723
+ results, extra_images = process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
724
+ results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
725
+ final_results = results + extra_images
726
+
727
+ # Save the generated images
728
+ save_images(results, prefix="relight")
729
+
730
+ return results
731
+
732
+
733
+ quick_prompts = [
734
+ 'sunshine from window',
735
+ 'neon light, city',
736
+ 'sunset over sea',
737
+ 'golden time',
738
+ 'sci-fi RGB glowing, cyberpunk',
739
+ 'natural lighting',
740
+ 'warm atmosphere, at home, bedroom',
741
+ 'magic lit',
742
+ 'evil, gothic, Yharnam',
743
+ 'light and shadow',
744
+ 'shadow from window',
745
+ 'soft studio lighting',
746
+ 'home atmosphere, cozy bedroom illumination',
747
+ 'neon, Wong Kar-wai, warm'
748
+ ]
749
+ quick_prompts = [[x] for x in quick_prompts]
750
+
751
+
752
+ quick_subjects = [
753
+ 'modern sofa, high quality leather',
754
+ 'elegant dining table, polished wood',
755
+ 'luxurious bed, premium mattress',
756
+ 'minimalist office desk, clean design',
757
+ 'vintage wooden cabinet, antique finish',
758
+ ]
759
+ quick_subjects = [[x] for x in quick_subjects]
760
+
761
+
762
+ class BGSource(Enum):
763
+ UPLOAD = "Use Background Image"
764
+ UPLOAD_FLIP = "Use Flipped Background Image"
765
+ LEFT = "Left Light"
766
+ RIGHT = "Right Light"
767
+ TOP = "Top Light"
768
+ BOTTOM = "Bottom Light"
769
+ GREY = "Ambient"
770
+
771
+ # Add save function
772
+ def save_images(images, prefix="relight"):
773
+ # Create output directory if it doesn't exist
774
+ output_dir = Path("outputs")
775
+ output_dir.mkdir(exist_ok=True)
776
+
777
+ # Create timestamp for unique filenames
778
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
779
+
780
+ saved_paths = []
781
+ for i, img in enumerate(images):
782
+ if isinstance(img, np.ndarray):
783
+ # Convert to PIL Image if numpy array
784
+ img = Image.fromarray(img)
785
+
786
+ # Create filename with timestamp
787
+ filename = f"{prefix}_{timestamp}_{i+1}.png"
788
+ filepath = output_dir / filename
789
+
790
+ # Save image
791
+ img.save(filepath)
792
+
793
+
794
+ # print(f"Saved {len(saved_paths)} images to {output_dir}")
795
+ return saved_paths
796
+
797
+
798
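+ # Keeps the extracted RGBA foreground around so it can be scaled and pasted onto a
+ # background at an arbitrary position when building the composite for relighting.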
+ class MaskMover:
799
+ def __init__(self):
800
+ self.extracted_fg = None
801
+ self.original_fg = None # Store original foreground
802
+
803
+ def set_extracted_fg(self, fg_image):
804
+ """Store the extracted foreground with alpha channel"""
805
+ if isinstance(fg_image, np.ndarray):
806
+ self.extracted_fg = fg_image.copy()
807
+ self.original_fg = fg_image.copy()
808
+ else:
809
+ self.extracted_fg = np.array(fg_image)
810
+ self.original_fg = np.array(fg_image)
811
+ return self.extracted_fg
812
+
813
+ def create_composite(self, background, x_pos, y_pos, scale=1.0):
814
+ """Create composite with foreground at specified position"""
815
+ if self.original_fg is None or background is None:
816
+ return background
817
+
818
+ # Convert inputs to PIL Images
819
+ if isinstance(background, np.ndarray):
820
+ bg = Image.fromarray(background).convert('RGBA')
821
+ else:
822
+ bg = background.convert('RGBA')
823
+
824
+ if isinstance(self.original_fg, np.ndarray):
825
+ fg = Image.fromarray(self.original_fg).convert('RGBA')
826
+ else:
827
+ fg = self.original_fg.convert('RGBA')
828
+
829
+ # Scale the foreground size
830
+ new_width = int(fg.width * scale)
831
+ new_height = int(fg.height * scale)
832
+ fg = fg.resize((new_width, new_height), Image.LANCZOS)
833
+
834
+ # Center the scaled foreground at the position
835
+ x = int(x_pos - new_width / 2)
836
+ y = int(y_pos - new_height / 2)
837
+
838
+ # Create composite
839
+ result = bg.copy()
840
+ result.paste(fg, (x, y), fg) # Use fg as the mask (requires fg to be in 'RGBA' mode)
841
+
842
+ return np.array(result.convert('RGB')) # Convert back to 'RGB' if needed
843
+
844
+ def get_depth(image):
845
+ if image is None:
846
+ return None
847
+ # Convert from PIL/gradio format to cv2
848
+ raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
849
+ # Get depth map
850
+ depth = model.infer_image(raw_img) # HxW raw depth map
851
+ # Normalize depth for visualization
852
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
853
+ # Convert to RGB for display
854
+ depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
855
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
856
+ return Image.fromarray(depth_colored)
857
+
858
+
859
+ from PIL import Image
860
+
861
+ def compress_image(image):
862
+ # Convert Gradio image (numpy array) to PIL Image
863
+ img = Image.fromarray(image)
864
+
865
+ # Resize image if dimensions are too large
866
+ max_size = 1024 # Maximum dimension size
867
+ if img.width > max_size or img.height > max_size:
868
+ ratio = min(max_size/img.width, max_size/img.height)
869
+ new_size = (int(img.width * ratio), int(img.height * ratio))
870
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
871
+
872
+ quality = 95 # Start with high quality
873
+ img.save("compressed_image.jpg", "JPEG", quality=quality) # Initial save
874
+
875
+ # Check file size and adjust quality if necessary
876
+ while os.path.getsize("compressed_image.jpg") > 100 * 1024: # 100KB limit
877
+ quality -= 5 # Decrease quality
878
+ img.save("compressed_image.jpg", "JPEG", quality=quality)
879
+ if quality < 20: # Prevent quality from going too low
880
+ break
881
+
882
+ # Convert back to numpy array for Gradio
883
+ compressed_img = np.array(Image.open("compressed_image.jpg"))
884
+ return compressed_img
885
+
886
+ def use_orientation(selected_image:gr.SelectData):
887
+ return selected_image.value['image']['path']
888
+
889
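+ # Open-vocabulary segmentation via the DDS cloud API (DINO-X): the image is uploaded,
+ # boxes and masks are returned for the prompted classes (or all objects if the prompt
+ # is empty), and the first detected object is cropped out as an RGBA cutout.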
+ @spaces.GPU(duration=60)
890
+ @torch.inference_mode()
891
+ def process_image(input_image, input_text):
892
+ """Main processing function for the Gradio interface"""
893
+
894
+ # Initialize configs
895
+ # Prefer an environment variable for the DDS API token; the hard-coded value is kept only as a fallback
+ API_TOKEN = os.environ.get("DDS_API_TOKEN", "9c8c865e10ec1821bea79d9fa9dc8720")
896
+ SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
897
+ SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
898
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
899
+ OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
900
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
901
+
902
+
903
+
904
+ # Initialize DDS client
905
+ config = Config(API_TOKEN)
906
+ client = Client(config)
907
+
908
+ # Process classes from text prompt
909
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
910
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
911
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
912
+
913
+ # Save input image to temp file and get URL
914
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
915
+ cv2.imwrite(tmpfile.name, input_image)
916
+ image_url = client.upload_file(tmpfile.name)
917
+ os.remove(tmpfile.name)
918
+
919
+ # Process detection results
920
+ input_boxes = []
921
+ masks = []
922
+ confidences = []
923
+ class_names = []
924
+ class_ids = []
925
+
926
+ if len(input_text) == 0:
927
+ task = DinoxTask(
928
+ image_url=image_url,
929
+ prompts=[TextPrompt(text="<prompt_free>")],
930
+ # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
931
+ )
932
+
933
+ client.run_task(task)
934
+ predictions = task.result.objects
935
+ classes = [pred.category for pred in predictions]
936
+ classes = list(set(classes))
937
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
938
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
939
+
940
+ for idx, obj in enumerate(predictions):
941
+ input_boxes.append(obj.bbox)
942
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
943
+ confidences.append(obj.score)
944
+ cls_name = obj.category.lower().strip()
945
+ class_names.append(cls_name)
946
+ class_ids.append(class_name_to_id[cls_name])
947
+
948
+ boxes = np.array(input_boxes)
949
+ masks = np.array(masks)
950
+ class_ids = np.array(class_ids)
951
+ labels = [
952
+ f"{class_name} {confidence:.2f}"
953
+ for class_name, confidence
954
+ in zip(class_names, confidences)
955
+ ]
956
+ detections = sv.Detections(
957
+ xyxy=boxes,
958
+ mask=masks.astype(bool),
959
+ class_id=class_ids
960
+ )
961
+
962
+ box_annotator = sv.BoxAnnotator()
963
+ label_annotator = sv.LabelAnnotator()
964
+ mask_annotator = sv.MaskAnnotator()
965
+
966
+ annotated_frame = input_image.copy()
967
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
968
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
969
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
970
+
971
+ # Create transparent mask for first detected object
972
+ if len(detections) > 0:
973
+ # Get first mask
974
+ first_mask = detections.mask[0]
975
+
976
+ # Get original RGB image
977
+ img = input_image.copy()
978
+ H, W, C = img.shape
979
+
980
+ # Create RGBA image
981
+ alpha = np.zeros((H, W, 1), dtype=np.uint8)
982
+ alpha[first_mask] = 255
983
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
984
+
985
+ # Crop to mask bounds to minimize image size
986
+ y_indices, x_indices = np.where(first_mask)
987
+ y_min, y_max = y_indices.min(), y_indices.max()
988
+ x_min, x_max = x_indices.min(), x_indices.max()
989
+
990
+ # Crop the RGBA image
991
+ cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
992
+
993
+ # Set extracted foreground for mask mover
994
+ mask_mover.set_extracted_fg(cropped_rgba)
995
+
996
+ return annotated_frame, cropped_rgba, gr.update(visible=False), gr.update(visible=False)
997
+
998
+
999
+ else:
1000
+ # Run DINO-X detection
1001
+ task = DinoxTask(
1002
+ image_url=image_url,
1003
+ prompts=[TextPrompt(text=input_text)],
1004
+ targets=[DetectionTarget.BBox, DetectionTarget.Mask]
1005
+ )
1006
+
1007
+ client.run_task(task)
1008
+ result = task.result
1009
+ objects = result.objects
1010
+
1011
+
1012
+
1013
+ # for obj in objects:
1014
+ # input_boxes.append(obj.bbox)
1015
+ # confidences.append(obj.score)
1016
+ # cls_name = obj.category.lower().strip()
1017
+ # class_names.append(cls_name)
1018
+ # class_ids.append(class_name_to_id[cls_name])
1019
+
1020
+ # input_boxes = np.array(input_boxes)
1021
+ # class_ids = np.array(class_ids)
1022
+
1023
+ predictions = task.result.objects
1024
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
1025
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
1026
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1027
+
1028
+ boxes = []
1029
+ masks = []
1030
+ confidences = []
1031
+ class_names = []
1032
+ class_ids = []
1033
+
1034
+ for idx, obj in enumerate(predictions):
1035
+ boxes.append(obj.bbox)
1036
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
1037
+ confidences.append(obj.score)
1038
+ cls_name = obj.category.lower().strip()
1039
+ class_names.append(cls_name)
1040
+ class_ids.append(class_name_to_id[cls_name])
1041
+
1042
+ boxes = np.array(boxes)
1043
+ masks = np.array(masks)
1044
+ class_ids = np.array(class_ids)
1045
+ labels = [
1046
+ f"{class_name} {confidence:.2f}"
1047
+ for class_name, confidence
1048
+ in zip(class_names, confidences)
1049
+ ]
1050
+
1051
+ # Initialize SAM2
1052
+ # torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
1053
+ # if torch.cuda.get_device_properties(0).major >= 8:
1054
+ # torch.backends.cuda.matmul.allow_tf32 = True
1055
+ # torch.backends.cudnn.allow_tf32 = True
1056
+
1057
+ # sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
1058
+ # sam2_predictor = SAM2ImagePredictor(sam2_model)
1059
+ # sam2_predictor.set_image(input_image)
1060
+
1061
+ # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
1062
+
1063
+
1064
+ # Get masks from SAM2
1065
+ # masks, scores, logits = sam2_predictor.predict(
1066
+ # point_coords=None,
1067
+ # point_labels=None,
1068
+ # box=input_boxes,
1069
+ # multimask_output=False,
1070
+ # )
1071
+
1072
+ if masks.ndim == 4:
1073
+ masks = masks.squeeze(1)
1074
+
1075
+ # Create visualization
1076
+ # labels = [f"{class_name} {confidence:.2f}"
1077
+ # for class_name, confidence in zip(class_names, confidences)]
1078
+
1079
+ # detections = sv.Detections(
1080
+ # xyxy=input_boxes,
1081
+ # mask=masks.astype(bool),
1082
+ # class_id=class_ids
1083
+ # )
1084
+
1085
+ detections = sv.Detections(
1086
+ xyxy = boxes,
1087
+ mask = masks.astype(bool),
1088
+ class_id = class_ids,
1089
+ )
1090
+
1091
+ box_annotator = sv.BoxAnnotator()
1092
+ label_annotator = sv.LabelAnnotator()
1093
+ mask_annotator = sv.MaskAnnotator()
1094
+
1095
+ annotated_frame = input_image.copy()
1096
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
1097
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
1098
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
1099
+
1100
+ # Create transparent mask for first detected object
1101
+ if len(detections) > 0:
1102
+ # Get first mask
1103
+ first_mask = detections.mask[0]
1104
+
1105
+ # Get original RGB image
1106
+ img = input_image.copy()
1107
+ H, W, C = img.shape
1108
+
1109
+ # Create RGBA image
1110
+ alpha = np.zeros((H, W, 1), dtype=np.uint8)
1111
+ alpha[first_mask] = 255
1112
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
1113
+
1114
+ # Crop to mask bounds to minimize image size
1115
+ y_indices, x_indices = np.where(first_mask)
1116
+ y_min, y_max = y_indices.min(), y_indices.max()
1117
+ x_min, x_max = x_indices.min(), x_indices.max()
1118
+
1119
+ # Crop the RGBA image
1120
+ cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
1121
+
1122
+ # Set extracted foreground for mask mover
1123
+ mask_mover.set_extracted_fg(cropped_rgba)
1124
+
1125
+ return annotated_frame, cropped_rgba, gr.update(visible=False), gr.update(visible=False)
1126
+ return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
1127
+
1128
+
1129
+ block = gr.Blocks().queue()
1130
+ with block:
1131
+ with gr.Tab("Text"):
1132
+ with gr.Row():
1133
+ gr.Markdown("## Product Placement from Text")
1134
+ with gr.Row():
1135
+ with gr.Column():
1136
+ with gr.Row():
1137
+ input_fg = gr.Image(type="numpy", label="Image", height=480)
1138
+ with gr.Row():
1139
+ with gr.Group():
1140
+ find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
1141
+ text_prompt = gr.Textbox(
1142
+ label="Text Prompt",
1143
+ placeholder="Enter object classes separated by periods (e.g. 'car . person .'), leave empty to get all objects",
1144
+ value=""
1145
+ )
1146
+ extract_button = gr.Button(value="Remove Background")
1147
+ with gr.Row():
1148
+ extracted_objects = gr.Image(type="numpy", label="Extracted Foreground", height=480)
1149
+ extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
1150
+
1151
+ with gr.Row():
1152
+ run_button = gr.Button("Generate alternative angles")
1153
+
1154
+ gallery_result = gr.Gallery(
1155
+ label="Result",
1156
+ show_label=False,
1157
+ columns=[3],
1158
+ rows=[2],
1159
+ object_fit="contain",
1160
+ height="auto",
1161
+ allow_preview=False,
1162
+ )
1163
+
1164
+ if gallery_result:
1165
+ selected = gr.Number(visible=True)
1166
+ gallery_result.select(use_orientation, inputs=None, outputs=extract_fg)
1167
+
1168
+ # output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
1169
+ with gr.Group():
1170
+ prompt = gr.Textbox(label="Prompt")
1171
+ bg_source = gr.Radio(choices=[e.value for e in list(BGSource)[2:]],
1172
+ value=BGSource.LEFT.value,
1173
+ label="Lighting Preference (Initial Latent)", type='value')
1174
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
1175
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
1176
+ relight_button = gr.Button(value="Relight")
1177
+
1178
+ with gr.Group(visible=False):
1179
+ with gr.Row():
1180
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1181
+ seed = gr.Number(label="Seed", value=12345, precision=0)
1182
+
1183
+ with gr.Row():
1184
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
1185
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
1186
+
1187
+ with gr.Accordion("Advanced options", open=False):
1188
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
1189
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01, visible=False)
1190
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
1191
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
1192
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
1193
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality', visible=False)
1194
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality', visible=False)
1195
+ x_slider = gr.Slider(
1196
+ minimum=0,
1197
+ maximum=1000,
1198
+ label="X Position",
1199
+ value=500,
1200
+ visible=False
1201
+ )
1202
+ y_slider = gr.Slider(
1203
+ minimum=0,
1204
+ maximum=1000,
1205
+ label="Y Position",
1206
+ value=500,
1207
+ visible=False
1208
+ )
1209
+ with gr.Column():
1210
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
1211
+ with gr.Row():
1212
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
1213
+ # gr.Examples(
1214
+ # fn=lambda *args: ([args[-1]], None),
1215
+ # examples=db_examples.foreground_conditioned_examples,
1216
+ # inputs=[
1217
+ # input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
1218
+ # ],
1219
+ # outputs=[result_gallery, output_bg],
1220
+ # run_on_click=True, examples_per_page=1024
1221
+ # )
1222
+ ips = [extracted_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
1223
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[result_gallery])
1224
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
1225
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
1226
+
1227
+ # infer expects (prompt, image); the original passed a raw string as an input
+ # component and wrote to an undefined `result` component
+ run_button.click(
+ fn=lambda img: infer("high quality", img),
+ inputs=[extracted_fg],
+ outputs=[gallery_result],
+ )
1234
+
1235
+ find_objects_button.click(
1236
+ fn=process_image,
1237
+ inputs=[input_fg, text_prompt],
1238
+ outputs=[extracted_objects, extracted_fg, x_slider, y_slider]
1239
+ )
1240
+
1241
+ extract_button.click(
1242
+ fn=extract_foreground,
1243
+ inputs=[input_fg],
1244
+ outputs=[extracted_fg, x_slider, y_slider]
1245
+ )
1246
+ with gr.Tab("Background", visible=False):
1247
+ # empty cache
1248
+
1249
+ mask_mover = MaskMover()
1250
+
1251
+ # with torch.no_grad():
1252
+ # # Update the input channels to 12
1253
+ # new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) # Changed from 8 to 12
1254
+ # new_conv_in.weight.zero_()
1255
+ # new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
1256
+ # new_conv_in.bias = unet.conv_in.bias
1257
+ # unet.conv_in = new_conv_in
1258
+ with gr.Row():
1259
+ gr.Markdown("## IC-Light (Relighting with Foreground and Background Condition)")
1260
+ gr.Markdown("💾 Generated images are automatically saved to 'outputs' folder")
1261
+
1262
+ with gr.Row():
1263
+ with gr.Column():
1264
+ # Step 1: Input and Extract
1265
+ with gr.Row():
1266
+ with gr.Group():
1267
+ gr.Markdown("### Step 1: Extract Foreground")
1268
+ input_image = gr.Image(type="numpy", label="Input Image", height=480)
1269
+ # find_objects_button = gr.Button(value="Find Objects")
1270
+ extract_button = gr.Button(value="Remove Background")
1271
+ extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
1272
+
1273
+ with gr.Row():
1274
+ # Step 2: Background and Position
1275
+ with gr.Group():
1276
+ gr.Markdown("### Step 2: Position on Background")
1277
+ input_bg = gr.Image(type="numpy", label="Background Image", height=480)
1278
+
1279
+ with gr.Row():
1280
+ x_slider = gr.Slider(
1281
+ minimum=0,
1282
+ maximum=1000,
1283
+ label="X Position",
1284
+ value=500,
1285
+ visible=False
1286
+ )
1287
+ y_slider = gr.Slider(
1288
+ minimum=0,
1289
+ maximum=1000,
1290
+ label="Y Position",
1291
+ value=500,
1292
+ visible=False
1293
+ )
1294
+ fg_scale_slider = gr.Slider(
1295
+ label="Foreground Scale",
1296
+ minimum=0.01,
1297
+ maximum=3.0,
1298
+ value=1.0,
1299
+ step=0.01
1300
+ )
1301
+
1302
+ editor = gr.ImageEditor(
1303
+ type="numpy",
1304
+ label="Position Foreground",
1305
+ height=480,
1306
+ visible=False
1307
+ )
1308
+ get_depth_button = gr.Button(value="Get Depth")
1309
+ depth_image = gr.Image(type="numpy", label="Depth Image", height=480)
1310
+
1311
+ # Step 3: Relighting Options
1312
+ with gr.Group():
1313
+ gr.Markdown("### Step 3: Relighting Settings")
1314
+ prompt = gr.Textbox(label="Prompt")
1315
+ bg_source = gr.Radio(
1316
+ choices=[e.value for e in BGSource],
1317
+ value=BGSource.UPLOAD.value,
1318
+ label="Background Source",
1319
+ type='value',
1320
+ visible=False
1321
+ )
1322
+
1323
+ example_prompts = gr.Dataset(
1324
+ samples=quick_prompts,
1325
+ label='Prompt Quick List',
1326
+ components=[prompt]
1327
+ )
1328
+ # bg_gallery = gr.Gallery(
1329
+ # height=450,
1330
+ # label='Background Quick List',
1331
+ # value=db_examples.bg_samples,
1332
+ # columns=5,
1333
+ # allow_preview=False
1334
+ # )
1335
+ relight_button_bg = gr.Button(value="Relight")
1336
+
1337
+ # Additional settings
1338
+ with gr.Group():
1339
+ with gr.Row():
1340
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1341
+ seed = gr.Number(label="Seed", value=12345, precision=0)
1342
+ with gr.Row():
1343
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
1344
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
1345
+
1346
+ with gr.Accordion("Advanced options", open=False):
1347
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1)
1348
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01)
1349
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=2.0, value=1.2, step=0.01)
1350
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01)
1351
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality')
1352
+ n_prompt = gr.Textbox(
1353
+ label="Negative Prompt",
1354
+ value='lowres, bad anatomy, bad hands, cropped, worst quality'
1355
+ )
1356
+
1357
+ with gr.Column():
1358
+ result_gallery = gr.Image(height=832, label='Outputs')
1359
+
1360
+ def extract_foreground(image):
1361
+ if image is None:
1362
+ return None, gr.update(visible=True), gr.update(visible=True)
1363
+ result, rgba = run_rmbg(image)
1364
+ mask_mover.set_extracted_fg(rgba)
1365
+
1366
+ return result, gr.update(visible=True), gr.update(visible=True)
1367
+
1368
+
1369
+ original_bg = None
1370
+
1371
+ extract_button.click(
1372
+ fn=extract_foreground,
1373
+ inputs=[input_image],
1374
+ outputs=[extracted_fg, x_slider, y_slider]
1375
+ )
1376
+
1377
+ find_objects_button.click(
1378
+ fn=process_image,
1379
+ inputs=[input_image, text_prompt],
1380
+ outputs=[extracted_objects, extracted_fg, x_slider, y_slider]
1381
+ )
1382
+
1383
+ get_depth_button.click(
1384
+ fn=get_depth,
1385
+ inputs=[input_bg],
1386
+ outputs=[depth_image]
1387
+ )
1388
+
1389
+
1390
+
1391
+ # def update_position(background, x_pos, y_pos, scale):
1392
+ # """Update composite when position changes"""
1393
+ # global original_bg
1394
+ # if background is None:
1395
+ # return None
1396
+
1397
+ # if original_bg is None:
1398
+ # original_bg = background.copy()
1399
+
1400
+ # # Convert string values to float
1401
+ # x_pos = float(x_pos)
1402
+ # y_pos = float(y_pos)
1403
+ # scale = float(scale)
1404
+
1405
+ # return mask_mover.create_composite(original_bg, x_pos, y_pos, scale)
1406
+
1407
+ class BackgroundManager:
1408
+ def __init__(self):
1409
+ self.original_bg = None
1410
+
1411
+ def update_position(self, background, x_pos, y_pos, scale):
1412
+ """Update composite when position changes"""
1413
+ if background is None:
1414
+ return None
1415
+
1416
+ if self.original_bg is None:
1417
+ self.original_bg = background.copy()
1418
+
1419
+ # Convert string values to float
1420
+ x_pos = float(x_pos)
1421
+ y_pos = float(y_pos)
1422
+ scale = float(scale)
1423
+
1424
+ return mask_mover.create_composite(self.original_bg, x_pos, y_pos, scale)
1425
+
1426
+ # Create an instance of BackgroundManager
1427
+ bg_manager = BackgroundManager()
1428
+
1429
+
1430
+ x_slider.change(
1431
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1432
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1433
+ outputs=[input_bg]
1434
+ )
1435
+
1436
+ y_slider.change(
1437
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1438
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1439
+ outputs=[input_bg]
1440
+ )
1441
+
1442
+ fg_scale_slider.change(
1443
+ fn=lambda bg, x, y, scale: bg_manager.update_position(bg, x, y, scale),
1444
+ inputs=[input_bg, x_slider, y_slider, fg_scale_slider],
1445
+ outputs=[input_bg]
1446
+ )
1447
+
1448
+ # Update inputs list to include fg_scale_slider
1449
+
1450
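+ # Rather than relighting the full background, a padded crop (20% margin, snapped to
+ # multiples of 8) is taken around the placed foreground, relit, and pasted back.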
+ def process_relight_with_position(*args):
1451
+ if mask_mover.extracted_fg is None:
1452
+ gr.Warning("Please extract foreground first")
1453
+ return None
1454
+
1455
+ background = args[1] # Get background image
1456
+ x_pos = float(args[-3]) # x_slider value
1457
+ y_pos = float(args[-2]) # y_slider value
1458
+ scale = float(args[-1]) # fg_scale_slider value
1459
+
1460
+ # Get original foreground size after scaling
1461
+ fg = Image.fromarray(mask_mover.original_fg)
1462
+ new_width = int(fg.width * scale)
1463
+ new_height = int(fg.height * scale)
1464
+
1465
+ # Calculate crop region around foreground position
1466
+ crop_x = int(x_pos - new_width/2)
1467
+ crop_y = int(y_pos - new_height/2)
1468
+ crop_width = new_width
1469
+ crop_height = new_height
1470
+
1471
+ # Add padding for context (20% extra on each side)
1472
+ padding = 0.2
1473
+ crop_x = int(crop_x - crop_width * padding)
1474
+ crop_y = int(crop_y - crop_height * padding)
1475
+ crop_width = int(crop_width * (1 + 2 * padding))
1476
+ crop_height = int(crop_height * (1 + 2 * padding))
1477
+
1478
+ # Ensure crop dimensions are multiples of 8
1479
+ crop_width = ((crop_width + 7) // 8) * 8
1480
+ crop_height = ((crop_height + 7) // 8) * 8
1481
+
1482
+ # Ensure crop region is within image bounds
1483
+ bg_height, bg_width = background.shape[:2]
1484
+ crop_x = max(0, min(crop_x, bg_width - crop_width))
1485
+ crop_y = max(0, min(crop_y, bg_height - crop_height))
1486
+
1487
+ # Get actual crop dimensions after boundary check
1488
+ crop_width = min(crop_width, bg_width - crop_x)
1489
+ crop_height = min(crop_height, bg_height - crop_y)
1490
+
1491
+ # Ensure dimensions are multiples of 8 again
1492
+ crop_width = (crop_width // 8) * 8
1493
+ crop_height = (crop_height // 8) * 8
1494
+
1495
+ # Crop region from background
1496
+ crop_region = background[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width]
1497
+
1498
+ # Create composite in cropped region
1499
+ fg_local_x = int(new_width/2 + crop_width*padding)
1500
+ fg_local_y = int(new_height/2 + crop_height*padding)
1501
+ cropped_composite = mask_mover.create_composite(crop_region, fg_local_x, fg_local_y, scale)
1502
+
1503
+ # Process the cropped region
1504
+ crop_args = list(args)
1505
+ crop_args[0] = cropped_composite
1506
+ crop_args[1] = crop_region
1507
+ crop_args[3] = crop_width
1508
+ crop_args[4] = crop_height
1509
+ crop_args = crop_args[:-3] # Remove position and scale arguments
1510
+
1511
+ # Get relit result
1512
+ relit_crop = process_relight_bg(*crop_args)[0]
1513
+
1514
+ # Resize relit result to match crop dimensions if needed
1515
+ if relit_crop.shape[:2] != (crop_height, crop_width):
1516
+ relit_crop = resize_without_crop(relit_crop, crop_width, crop_height)
1517
+
1518
+ # Place relit crop back into original background
1519
+ result = background.copy()
1520
+ result[crop_y:crop_y+crop_height, crop_x:crop_x+crop_width] = relit_crop
1521
+
1522
+ return result
1523
+
1524
+ # include the position/scale controls so process_relight_with_position can read them;
+ # they are stripped off again before calling process_relight_bg
+ ips_bg = [input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source, x_slider, y_slider, fg_scale_slider]
1525
+
1526
+ # Update button click events with new inputs list
1527
+ relight_button_bg.click(
1528
+ fn=process_relight_with_position,
1529
+ inputs=ips_bg,
1530
+ outputs=[result_gallery]
1531
+ )
1532
+
1533
+
1534
+ example_prompts.click(
1535
+ fn=lambda x: x[0],
1536
+ inputs=example_prompts,
1537
+ outputs=prompt,
1538
+ show_progress=False,
1539
+ queue=False
1540
+ )
1541
+
1542
+
1543
+ block.launch(server_name='0.0.0.0', share=False)