Ashoka74 committed on
Commit
4e97c85
1 Parent(s): b9d2b0c

Create app_3.py

Files changed (1)
  1. app_3.py +1302 -0
app_3.py ADDED
@@ -0,0 +1,1302 @@
1
+ import spaces
2
+ import argparse
3
+ import random
4
+
5
+ import os
6
+ import math
7
+ import gradio as gr
8
+ import numpy as np
9
+ import torch
10
+ import safetensors.torch as sf
11
+ import datetime
12
+ from pathlib import Path
13
+ from io import BytesIO
14
+
15
+
16
+
17
+ from PIL import Image
18
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
19
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
20
+ from diffusers.models.attention_processor import AttnProcessor2_0
21
+ from transformers import CLIPTextModel, CLIPTokenizer
22
+ import dds_cloudapi_sdk
23
+ from dds_cloudapi_sdk import Config, Client, TextPrompt
24
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
25
+ from dds_cloudapi_sdk.tasks import DetectionTarget
26
+ from dds_cloudapi_sdk.tasks.detection import DetectionTask
27
+
28
+ from enum import Enum
29
+ from torch.hub import download_url_to_file
30
+ import tempfile
31
+
32
+ from sam2.build_sam import build_sam2
33
+
34
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
35
+ import cv2
36
+
37
+ from transformers import AutoModelForImageSegmentation
38
+ from inference_i2mv_sdxl import prepare_pipeline, remove_bg, run_pipeline
39
+ from torchvision import transforms
40
+
41
+
42
+ from typing import Optional
43
+
44
+ from dpt import DepthAnythingV2
45
+
46
+ import httpx
47
+
48
+ client = httpx.Client(timeout=httpx.Timeout(10.0)) # Set timeout to 10 seconds
49
+ NUM_VIEWS = 6
50
+ HEIGHT = 768
51
+ WIDTH = 768
52
+ MAX_SEED = np.iinfo(np.int32).max
53
+
54
+
55
+ import supervision as sv
56
+ import torch
57
+ from PIL import Image
58
+
59
+ import logging
60
+
61
+ # Configure logging
62
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
63
+
64
+ transform_image = transforms.Compose(
65
+ [
66
+ transforms.Resize((1024, 1024)),
67
+ transforms.ToTensor(),
68
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
69
+ ]
70
+ )
71
+
72
+ # Load
73
+
74
+ # Model paths
75
+ model_path = './models/iclight_sd15_fc.safetensors'
76
+ model_path2 = './checkpoints/depth_anything_v2_vits.pth'
77
+ model_path3 = './checkpoints/sam2_hiera_large.pt'
78
+ model_path4 = './checkpoints/config.json'
79
+ model_path5 = './checkpoints/preprocessor_config.json'
80
+ model_path6 = './configs/sam2_hiera_l.yaml'
81
+ model_path7 = './mvadapter_i2mv_sdxl.safetensors'
82
+
83
+ # Base URL for the repository
84
+ BASE_URL = 'https://huggingface.co/Ashoka74/Placement/resolve/main/'
85
+
86
+ # Model URLs
87
+ model_urls = {
88
+ model_path: 'iclight_sd15_fc.safetensors',
89
+ model_path2: 'depth_anything_v2_vits.pth',
90
+ model_path3: 'sam2_hiera_large.pt',
91
+ model_path4: 'config.json',
92
+ model_path5: 'preprocessor_config.json',
93
+ model_path6: 'sam2_hiera_l.yaml',
94
+ model_path7: 'mvadapter_i2mv_sdxl.safetensors'
95
+ }
96
+
97
+ # Ensure directories exist
98
+ def ensure_directories():
99
+ for path in model_urls.keys():
100
+ os.makedirs(os.path.dirname(path), exist_ok=True)
101
+
102
+ # Download models
103
+ def download_models():
104
+ for local_path, filename in model_urls.items():
105
+ if not os.path.exists(local_path):
106
+ try:
107
+ url = f"{BASE_URL}{filename}"
108
+ print(f"Downloading {filename}")
109
+ download_url_to_file(url, local_path)
110
+ print(f"Successfully downloaded {filename}")
111
+ except Exception as e:
112
+ print(f"Error downloading {filename}: {e}")
113
+
114
+ ensure_directories()
115
+
116
+ download_models()
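+ # The helper above only downloads files that are missing locally, so restarts reuse the cache.
+ # A roughly equivalent alternative (not used here) would be the huggingface_hub client, e.g.:
+ #   from huggingface_hub import hf_hub_download
+ #   hf_hub_download(repo_id="Ashoka74/Placement", filename="iclight_sd15_fc.safetensors", local_dir="./models")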
117
+
118
+
119
+
120
+
121
+
122
+ try:
123
+ import xformers
124
+ import xformers.ops
125
+ XFORMERS_AVAILABLE = True
126
+ print("xformers is available - Using memory efficient attention")
127
+ except ImportError:
128
+ XFORMERS_AVAILABLE = False
129
+ print("xformers not available - Using default attention")
130
+
131
+ # Memory optimizations for RTX 2070
132
+ torch.backends.cudnn.benchmark = True
133
+ if torch.cuda.is_available():
134
+ torch.backends.cuda.matmul.allow_tf32 = True
135
+ torch.backends.cudnn.allow_tf32 = True
136
+ # Set a smaller attention slice size for RTX 2070
137
+ torch.backends.cuda.max_split_size_mb = 512
138
+ device = torch.device('cuda')
139
+ else:
140
+ device = torch.device('cpu')
141
+
142
+
143
+
144
+ # 'stablediffusionapi/realistic-vision-v51'
145
+ # 'runwayml/stable-diffusion-v1-5'
146
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
147
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
148
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
149
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
150
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
151
+ # Load the RMBG background-removal model (AutoModelForImageSegmentation is already imported above)
153
+ rmbg = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
154
+ rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
155
+
156
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
157
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
158
+ model = model.to(device)
159
+ model.eval()
160
+
161
+ # Change UNet
162
+
163
+
164
+ with torch.no_grad():
165
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
166
+ new_conv_in.weight.zero_()
167
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
168
+ new_conv_in.bias = unet.conv_in.bias
169
+ unet.conv_in = new_conv_in
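+ # The zero-initialized extra input channels let the UNet accept a concatenated foreground
+ # latent: [noisy latent (4ch), foreground latent (4ch)] = 8 channels, as wired up in
+ # hooked_unet_forward below.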
170
+
171
+
172
+ unet_original_forward = unet.forward
173
+
174
+
175
+ def enable_efficient_attention():
176
+ if XFORMERS_AVAILABLE:
177
+ try:
178
+ # RTX 2070 specific settings
179
+ unet.set_use_memory_efficient_attention_xformers(True)
180
+ vae.set_use_memory_efficient_attention_xformers(True)
181
+ print("Enabled xformers memory efficient attention")
182
+ except Exception as e:
183
+ print(f"Xformers error: {e}")
184
+ print("Falling back to sliced attention")
185
+ # Use sliced attention for RTX 2070
186
+ # unet.set_attention_slice_size(4)
187
+ # vae.set_attention_slice_size(4)
188
+ unet.set_attn_processor(AttnProcessor2_0())
189
+ vae.set_attn_processor(AttnProcessor2_0())
190
+ else:
191
+ # Fallback for when xformers is not available
192
+ print("Using sliced attention")
193
+ # unet.set_attention_slice_size(4)
194
+ # vae.set_attention_slice_size(4)
195
+ unet.set_attn_processor(AttnProcessor2_0())
196
+ vae.set_attn_processor(AttnProcessor2_0())
197
+
198
+ # Add memory clearing function
199
+ def clear_memory():
200
+ if torch.cuda.is_available():
201
+ torch.cuda.empty_cache()
202
+ torch.cuda.synchronize()
203
+
204
+ # Enable efficient attention
205
+ enable_efficient_attention()
206
+
207
+
208
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
209
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
210
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
211
+ new_sample = torch.cat([sample, c_concat], dim=1)
212
+ kwargs['cross_attention_kwargs'] = {}
213
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
214
+
215
+
216
+ unet.forward = hooked_unet_forward
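+ # The hook pulls the foreground latent out of cross_attention_kwargs['concat_conds'],
+ # repeats it to match the batch size, and concatenates it onto the noisy latent before
+ # calling the original forward.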
217
+
218
+
219
+
220
+
221
+ sd_offset = sf.load_file(model_path)
222
+ sd_origin = unet.state_dict()
223
+ keys = sd_origin.keys()
224
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
225
+ unet.load_state_dict(sd_merged, strict=True)
226
+ del sd_offset, sd_origin, sd_merged, keys
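+ # The IC-Light checkpoint is stored as additive offsets, so the merge above adds each
+ # offset tensor onto the corresponding base UNet weight rather than replacing it.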
227
+
228
+ # Device
229
+
230
+ # device = torch.device('cuda')
231
+ # text_encoder = text_encoder.to(device=device, dtype=torch.float16)
232
+ # vae = vae.to(device=device, dtype=torch.bfloat16)
233
+ # unet = unet.to(device=device, dtype=torch.float16)
234
+ # rmbg = rmbg.to(device=device, dtype=torch.float32)
235
+
236
+
237
+ # Device and dtype setup
238
+ device = torch.device('cuda')
239
+ dtype = torch.float16 # RTX 2070 works well with float16
240
+
241
+ pipe = prepare_pipeline(
242
+ base_model="stabilityai/stable-diffusion-xl-base-1.0",
243
+ vae_model="madebyollin/sdxl-vae-fp16-fix",
244
+ unet_model=None,
245
+ lora_model=None,
246
+ adapter_path="huanngzh/mv-adapter",
247
+ scheduler=None,
248
+ num_views=NUM_VIEWS,
249
+ device=device,
250
+ dtype=dtype,
251
+ )
252
+
253
+ # Memory optimizations for RTX 2070
254
+ torch.backends.cudnn.benchmark = True
255
+ if torch.cuda.is_available():
256
+ torch.backends.cuda.matmul.allow_tf32 = True
257
+ torch.backends.cudnn.allow_tf32 = True
258
+ # Set a very small attention slice size for RTX 2070 to avoid OOM
259
+ torch.backends.cuda.max_split_size_mb = 128
260
+
261
+ # Move models to device with consistent dtype
262
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
263
+ vae = vae.to(device=device, dtype=dtype) # Changed from bfloat16 to float16
264
+ unet = unet.to(device=device, dtype=dtype)
265
+ rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
266
+
267
+
268
+ ddim_scheduler = DDIMScheduler(
269
+ num_train_timesteps=1000,
270
+ beta_start=0.00085,
271
+ beta_end=0.012,
272
+ beta_schedule="scaled_linear",
273
+ clip_sample=False,
274
+ set_alpha_to_one=False,
275
+ steps_offset=1,
276
+ )
277
+
278
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
279
+ num_train_timesteps=1000,
280
+ beta_start=0.00085,
281
+ beta_end=0.012,
282
+ steps_offset=1
283
+ )
284
+
285
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
286
+ num_train_timesteps=1000,
287
+ beta_start=0.00085,
288
+ beta_end=0.012,
289
+ algorithm_type="sde-dpmsolver++",
290
+ use_karras_sigmas=True,
291
+ steps_offset=1
292
+ )
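+ # Any of the three schedulers above can be swapped into the pipelines below, e.g.
+ # t2i_pipe.scheduler = euler_a_scheduler; DPM++ 2M SDE Karras is what is wired in by default.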
293
+
294
+ # Pipelines
295
+
296
+ t2i_pipe = StableDiffusionPipeline(
297
+ vae=vae,
298
+ text_encoder=text_encoder,
299
+ tokenizer=tokenizer,
300
+ unet=unet,
301
+ scheduler=dpmpp_2m_sde_karras_scheduler,
302
+ safety_checker=None,
303
+ requires_safety_checker=False,
304
+ feature_extractor=None,
305
+ image_encoder=None
306
+ )
307
+
308
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
309
+ vae=vae,
310
+ text_encoder=text_encoder,
311
+ tokenizer=tokenizer,
312
+ unet=unet,
313
+ scheduler=dpmpp_2m_sde_karras_scheduler,
314
+ safety_checker=None,
315
+ requires_safety_checker=False,
316
+ feature_extractor=None,
317
+ image_encoder=None
318
+ )
319
+
320
+
321
+ @torch.inference_mode()
322
+ def encode_prompt_inner(txt: str):
323
+ max_length = tokenizer.model_max_length
324
+ chunk_length = tokenizer.model_max_length - 2
325
+ id_start = tokenizer.bos_token_id
326
+ id_end = tokenizer.eos_token_id
327
+ id_pad = id_end
328
+
329
+ def pad(x, p, i):
330
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
331
+
332
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
333
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
334
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
335
+
336
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
337
+ conds = text_encoder(token_ids).last_hidden_state
338
+
339
+ return conds
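+ # Prompts longer than CLIP's 77-token window are split into 75-token chunks, each wrapped
+ # with BOS/EOS and padded, then encoded separately so nothing is silently truncated.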
340
+
341
+
342
+ @torch.inference_mode()
343
+ def encode_prompt_pair(positive_prompt, negative_prompt):
344
+ c = encode_prompt_inner(positive_prompt)
345
+ uc = encode_prompt_inner(negative_prompt)
346
+
347
+ c_len = float(len(c))
348
+ uc_len = float(len(uc))
349
+ max_count = max(c_len, uc_len)
350
+ c_repeat = int(math.ceil(max_count / c_len))
351
+ uc_repeat = int(math.ceil(max_count / uc_len))
352
+ max_chunk = max(len(c), len(uc))
353
+
354
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
355
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
356
+
357
+ c = torch.cat([p[None, ...] for p in c], dim=1)
358
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
359
+
360
+ return c, uc
361
+
362
+ @spaces.GPU(duration=60)
363
+ @torch.inference_mode()
366
+ def infer(
367
+ prompt,
368
+ image, # This is already RGBA with background removed
369
+ do_rembg=True,
370
+ seed=42,
371
+ randomize_seed=False,
372
+ guidance_scale=3.0,
373
+ num_inference_steps=50,
374
+ reference_conditioning_scale=1.0,
375
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
376
+ progress=gr.Progress(track_tqdm=True),
377
+ ):
378
+ logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
379
+
380
+ # Convert input to PIL if needed
381
+ if isinstance(image, np.ndarray):
382
+ if image.shape[-1] == 4: # RGBA
383
+ image = Image.fromarray(image, 'RGBA')
384
+ else: # RGB
385
+ image = Image.fromarray(image, 'RGB')
386
+
387
+ logging.info(f"Converted to PIL Image mode: {image.mode}")
388
+
389
+ # No need for remove_bg_fn since image is already processed
390
+ remove_bg_fn = None
391
+
392
+ if randomize_seed:
393
+ seed = random.randint(0, MAX_SEED)
394
+
395
+ images, preprocessed_image = run_pipeline(
396
+ pipe,
397
+ num_views=NUM_VIEWS,
398
+ text=prompt,
399
+ image=image,
400
+ height=HEIGHT,
401
+ width=WIDTH,
402
+ num_inference_steps=num_inference_steps,
403
+ guidance_scale=guidance_scale,
404
+ seed=seed,
405
+ remove_bg_fn=remove_bg_fn, # Set to None since preprocessing is done
406
+ reference_conditioning_scale=reference_conditioning_scale,
407
+ negative_prompt=negative_prompt,
408
+ device=device,
409
+ )
410
+
411
+ logging.info(f"Output images shape: {[img.shape for img in images]}")
412
+ logging.info(f"Preprocessed image shape: {preprocessed_image.shape if preprocessed_image is not None else None}")
413
+ return images
414
+
415
+
416
+ @spaces.GPU(duration=60)
417
+ @torch.inference_mode()
418
+ def pytorch2numpy(imgs, quant=True):
419
+ results = []
420
+ for x in imgs:
421
+ y = x.movedim(0, -1)
422
+
423
+ if quant:
424
+ y = y * 127.5 + 127.5
425
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
426
+ else:
427
+ y = y * 0.5 + 0.5
428
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
429
+
430
+ results.append(y)
431
+ return results
432
+
433
+ @spaces.GPU(duration=60)
434
+ @torch.inference_mode()
435
+ def numpy2pytorch(imgs):
436
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that 127 must be strictly 0.0
437
+ h = h.movedim(-1, 1)
438
+ return h
439
+
440
+
441
+ def resize_and_center_crop(image, target_width, target_height):
442
+ pil_image = Image.fromarray(image)
443
+ original_width, original_height = pil_image.size
444
+ scale_factor = max(target_width / original_width, target_height / original_height)
445
+ resized_width = int(round(original_width * scale_factor))
446
+ resized_height = int(round(original_height * scale_factor))
447
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
448
+ left = (resized_width - target_width) / 2
449
+ top = (resized_height - target_height) / 2
450
+ right = (resized_width + target_width) / 2
451
+ bottom = (resized_height + target_height) / 2
452
+ cropped_image = resized_image.crop((left, top, right, bottom))
453
+ return np.array(cropped_image)
454
+
455
+
456
+ def resize_without_crop(image, target_width, target_height):
457
+ pil_image = Image.fromarray(image)
458
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
459
+ return np.array(resized_image)
460
+
461
+ @spaces.GPU(duration=60)
462
+ @torch.inference_mode()
463
+ def run_rmbg(img, sigma=0.0):
464
+ # Convert RGBA to RGB if needed
465
+ if img.shape[-1] == 4:
466
+ # Use white background for alpha composition
467
+ alpha = img[..., 3:] / 255.0
468
+ rgb = img[..., :3]
469
+ white_bg = np.ones_like(rgb) * 255
470
+ img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
471
+
472
+ H, W, C = img.shape
473
+ assert C == 3
474
+ k = (256.0 / float(H * W)) ** 0.5
475
+ feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
476
+ feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
477
+ alpha = rmbg(feed)[0][0]
478
+ alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
479
+ alpha = alpha.movedim(1, -1)[0]
480
+ alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
481
+
482
+ # Create RGBA image
483
+ rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
484
+ result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
485
+ return result.clip(0, 255).astype(np.uint8), rgba
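+ # run_rmbg returns two images: a grey-blended preview (foreground composited over 127-grey)
+ # and the RGBA cutout carrying the predicted alpha channel.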
486
+
487
+ @spaces.GPU(duration=60)
488
+ @torch.inference_mode()
489
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
490
+ clear_memory()
491
+
492
+ # Get input dimensions
493
+ input_height, input_width = input_fg.shape[:2]
494
+
495
+ bg_source = BGSource(bg_source)
496
+
497
+ # process() has no background-image input, so start from None; the grey/gradient
+ # branches below synthesize an initial-latent background when requested.
+ input_bg = None
498
+ if bg_source == BGSource.UPLOAD:
499
+ pass
500
+ elif bg_source == BGSource.UPLOAD_FLIP:
501
+ input_bg = np.fliplr(input_bg)
502
+ elif bg_source == BGSource.GREY:
503
+ input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
504
+ elif bg_source == BGSource.LEFT:
505
+ gradient = np.linspace(255, 0, input_width)
506
+ image = np.tile(gradient, (input_height, 1))
507
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
508
+ elif bg_source == BGSource.RIGHT:
509
+ gradient = np.linspace(0, 255, input_width)
510
+ image = np.tile(gradient, (input_height, 1))
511
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
512
+ elif bg_source == BGSource.TOP:
513
+ gradient = np.linspace(255, 0, input_height)[:, None]
514
+ image = np.tile(gradient, (1, input_width))
515
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
516
+ elif bg_source == BGSource.BOTTOM:
517
+ gradient = np.linspace(0, 255, input_height)[:, None]
518
+ image = np.tile(gradient, (1, input_width))
519
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
520
+ else:
521
+ raise ValueError('Wrong initial latent!')
522
+
523
+ rng = torch.Generator(device=device).manual_seed(int(seed))
524
+
525
+ # Use input dimensions directly
526
+ fg = resize_without_crop(input_fg, input_width, input_height)
527
+
528
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
529
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
530
+
531
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
532
+
533
+ if input_bg is None:
534
+ latents = t2i_pipe(
535
+ prompt_embeds=conds,
536
+ negative_prompt_embeds=unconds,
537
+ width=input_width,
538
+ height=input_height,
539
+ num_inference_steps=steps,
540
+ num_images_per_prompt=num_samples,
541
+ generator=rng,
542
+ output_type='latent',
543
+ guidance_scale=cfg,
544
+ cross_attention_kwargs={'concat_conds': concat_conds},
545
+ ).images.to(vae.dtype) / vae.config.scaling_factor
546
+ else:
547
+ bg = resize_without_crop(input_bg, input_width, input_height)
548
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
549
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
550
+ latents = i2i_pipe(
551
+ image=bg_latent,
552
+ strength=lowres_denoise,
553
+ prompt_embeds=conds,
554
+ negative_prompt_embeds=unconds,
555
+ width=input_width,
556
+ height=input_height,
557
+ num_inference_steps=int(round(steps / lowres_denoise)),
558
+ num_images_per_prompt=num_samples,
559
+ generator=rng,
560
+ output_type='latent',
561
+ guidance_scale=cfg,
562
+ cross_attention_kwargs={'concat_conds': concat_conds},
563
+ ).images.to(vae.dtype) / vae.config.scaling_factor
564
+
565
+ pixels = vae.decode(latents).sample
566
+ pixels = pytorch2numpy(pixels)
567
+ pixels = [resize_without_crop(
568
+ image=p,
569
+ target_width=int(round(input_width * highres_scale / 64.0) * 64),
570
+ target_height=int(round(input_height * highres_scale / 64.0) * 64))
571
+ for p in pixels]
572
+
573
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
574
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
575
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
576
+
577
+ highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
578
+
579
+ fg = resize_without_crop(input_fg, highres_width, highres_height)
580
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
581
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
582
+
583
+ latents = i2i_pipe(
584
+ image=latents,
585
+ strength=highres_denoise,
586
+ prompt_embeds=conds,
587
+ negative_prompt_embeds=unconds,
588
+ width=highres_width,
589
+ height=highres_height,
590
+ num_inference_steps=int(round(steps / highres_denoise)),
591
+ num_images_per_prompt=num_samples,
592
+ generator=rng,
593
+ output_type='latent',
594
+ guidance_scale=cfg,
595
+ cross_attention_kwargs={'concat_conds': concat_conds},
596
+ ).images.to(vae.dtype) / vae.config.scaling_factor
597
+
598
+ pixels = vae.decode(latents).sample
599
+ pixels = pytorch2numpy(pixels)
600
+
601
+ # Resize back to input dimensions
602
+ pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
603
+ pixels = np.stack(pixels)
604
+
605
+ return pixels
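+ # process() is a two-pass relight: a first pass at the input resolution (text-to-image, or
+ # img2img against the synthesized background), then an upscale by highres_scale followed by
+ # an img2img refinement pass, before resizing back to the input dimensions.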
606
+
607
+ def extract_foreground(image):
608
+ if image is None:
609
+ return None, gr.update(visible=True), gr.update(visible=True)
610
+ logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
611
+ result, rgba = run_rmbg(image)
612
+ logging.info(f"Result shape: {result.shape}, dtype: {result.dtype}")
613
+ logging.info(f"RGBA shape: {rgba.shape}, dtype: {rgba.dtype}")
614
+ return result, gr.update(visible=True), gr.update(visible=True)
615
+
616
+ @torch.inference_mode()
617
+ def process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
618
+ clear_memory()
619
+ bg_source = BGSource(bg_source)
620
+
621
+ if bg_source == BGSource.UPLOAD:
622
+ pass
623
+ elif bg_source == BGSource.UPLOAD_FLIP:
624
+ input_bg = np.fliplr(input_bg)
625
+ elif bg_source == BGSource.GREY:
626
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
627
+ elif bg_source == BGSource.LEFT:
628
+ gradient = np.linspace(224, 32, image_width)
629
+ image = np.tile(gradient, (image_height, 1))
630
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
631
+ elif bg_source == BGSource.RIGHT:
632
+ gradient = np.linspace(32, 224, image_width)
633
+ image = np.tile(gradient, (image_height, 1))
634
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
635
+ elif bg_source == BGSource.TOP:
636
+ gradient = np.linspace(224, 32, image_height)[:, None]
637
+ image = np.tile(gradient, (1, image_width))
638
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
639
+ elif bg_source == BGSource.BOTTOM:
640
+ gradient = np.linspace(32, 224, image_height)[:, None]
641
+ image = np.tile(gradient, (1, image_width))
642
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
643
+ else:
644
+ raise ValueError('Wrong background source!')
645
+
646
+ rng = torch.Generator(device=device).manual_seed(seed)
647
+
648
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
649
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
650
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
651
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
652
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
653
+
654
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
655
+
656
+ latents = t2i_pipe(
657
+ prompt_embeds=conds,
658
+ negative_prompt_embeds=unconds,
659
+ width=image_width,
660
+ height=image_height,
661
+ num_inference_steps=steps,
662
+ num_images_per_prompt=num_samples,
663
+ generator=rng,
664
+ output_type='latent',
665
+ guidance_scale=cfg,
666
+ cross_attention_kwargs={'concat_conds': concat_conds},
667
+ ).images.to(vae.dtype) / vae.config.scaling_factor
668
+
669
+ pixels = vae.decode(latents).sample
670
+ pixels = pytorch2numpy(pixels)
671
+ pixels = [resize_without_crop(
672
+ image=p,
673
+ target_width=int(round(image_width * highres_scale / 64.0) * 64),
674
+ target_height=int(round(image_height * highres_scale / 64.0) * 64))
675
+ for p in pixels]
676
+
677
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
678
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
679
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
680
+
681
+ image_height, image_width = latents.shape[2] * 8, latents.shape[3] * 8
682
+ fg = resize_and_center_crop(input_fg, image_width, image_height)
683
+ bg = resize_and_center_crop(input_bg, image_width, image_height)
684
+ concat_conds = numpy2pytorch([fg, bg]).to(device=vae.device, dtype=vae.dtype)
685
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
686
+ concat_conds = torch.cat([c[None, ...] for c in concat_conds], dim=1)
687
+
688
+ latents = i2i_pipe(
689
+ image=latents,
690
+ strength=highres_denoise,
691
+ prompt_embeds=conds,
692
+ negative_prompt_embeds=unconds,
693
+ width=image_width,
694
+ height=image_height,
695
+ num_inference_steps=int(round(steps / highres_denoise)),
696
+ num_images_per_prompt=num_samples,
697
+ generator=rng,
698
+ output_type='latent',
699
+ guidance_scale=cfg,
700
+ cross_attention_kwargs={'concat_conds': concat_conds},
701
+ ).images.to(vae.dtype) / vae.config.scaling_factor
702
+
703
+ pixels = vae.decode(latents).sample
704
+ pixels = pytorch2numpy(pixels, quant=False)
705
+
706
+ clear_memory()
707
+ return pixels, [fg, bg]
708
+
709
+
710
+ @torch.inference_mode()
711
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
712
+ logging.info(f"Input foreground shape: {input_fg.shape}, dtype: {input_fg.dtype}")
713
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
714
+ logging.info(f"Results shape: {results.shape}, dtype: {results.dtype}")
715
+ return list(results)  # gr.Gallery expects a list of images rather than a stacked array
716
+
717
+
718
+
719
+ @torch.inference_mode()
720
+ def process_relight_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source):
721
+ bg_source = BGSource(bg_source)
722
+
723
+ # bg_source = "Use Background Image"
724
+
725
+ # Convert numerical inputs to appropriate types
726
+ image_width = int(image_width)
727
+ image_height = int(image_height)
728
+ num_samples = int(num_samples)
729
+ seed = int(seed)
730
+ steps = int(steps)
731
+ cfg = float(cfg)
732
+ highres_scale = float(highres_scale)
733
+ highres_denoise = float(highres_denoise)
734
+
735
+ if bg_source == BGSource.UPLOAD:
736
+ pass
737
+ elif bg_source == BGSource.UPLOAD_FLIP:
738
+ input_bg = np.fliplr(input_bg)
739
+ elif bg_source == BGSource.GREY:
740
+ input_bg = np.zeros(shape=(image_height, image_width, 3), dtype=np.uint8) + 64
741
+ elif bg_source == BGSource.LEFT:
742
+ gradient = np.linspace(224, 32, image_width)
743
+ image = np.tile(gradient, (image_height, 1))
744
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
745
+ elif bg_source == BGSource.RIGHT:
746
+ gradient = np.linspace(32, 224, image_width)
747
+ image = np.tile(gradient, (image_height, 1))
748
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
749
+ elif bg_source == BGSource.TOP:
750
+ gradient = np.linspace(224, 32, image_height)[:, None]
751
+ image = np.tile(gradient, (1, image_width))
752
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
753
+ elif bg_source == BGSource.BOTTOM:
754
+ gradient = np.linspace(32, 224, image_height)[:, None]
755
+ image = np.tile(gradient, (1, image_width))
756
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
757
+ else:
758
+ raise ValueError('Wrong background source!')
759
+
760
+ input_fg, matting = run_rmbg(input_fg)
761
+ results, extra_images = process_bg(input_fg, input_bg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source)
762
+ results = [(x * 255.0).clip(0, 255).astype(np.uint8) for x in results]
763
+ final_results = results + extra_images
764
+
765
+ # Save the generated images
766
+ save_images(results, prefix="relight")
767
+
768
+ return results
769
+
770
+
771
+ quick_prompts = [
772
+ 'sunshine from window',
773
+ 'neon light, city',
774
+ 'sunset over sea',
775
+ 'golden time',
776
+ 'sci-fi RGB glowing, cyberpunk',
777
+ 'natural lighting',
778
+ 'warm atmosphere, at home, bedroom',
779
+ 'magic lit',
780
+ 'evil, gothic, Yharnam',
781
+ 'light and shadow',
782
+ 'shadow from window',
783
+ 'soft studio lighting',
784
+ 'home atmosphere, cozy bedroom illumination',
785
+ 'neon, Wong Kar-wai, warm'
786
+ ]
787
+ quick_prompts = [[x] for x in quick_prompts]
788
+
789
+
790
+ quick_subjects = [
791
+ 'modern sofa, high quality leather',
792
+ 'elegant dining table, polished wood',
793
+ 'luxurious bed, premium mattress',
794
+ 'minimalist office desk, clean design',
795
+ 'vintage wooden cabinet, antique finish',
796
+ ]
797
+ quick_subjects = [[x] for x in quick_subjects]
798
+
799
+
800
+ class BGSource(Enum):
801
+ UPLOAD = "Use Background Image"
802
+ UPLOAD_FLIP = "Use Flipped Background Image"
803
+ LEFT = "Left Light"
804
+ RIGHT = "Right Light"
805
+ TOP = "Top Light"
806
+ BOTTOM = "Bottom Light"
807
+ GREY = "Ambient"
808
+
809
+ # Add save function
810
+ def save_images(images, prefix="relight"):
811
+ # Create output directory if it doesn't exist
812
+ output_dir = Path("outputs")
813
+ output_dir.mkdir(exist_ok=True)
814
+
815
+ # Create timestamp for unique filenames
816
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
817
+
818
+ saved_paths = []
819
+ for i, img in enumerate(images):
820
+ if isinstance(img, np.ndarray):
821
+ # Convert to PIL Image if numpy array
822
+ img = Image.fromarray(img)
823
+
824
+ # Create filename with timestamp
825
+ filename = f"{prefix}_{timestamp}_{i+1}.png"
826
+ filepath = output_dir / filename
827
+
828
+ # Save image
829
+ img.save(filepath)
+ saved_paths.append(filepath)
830
+
831
+
832
+ # print(f"Saved {len(saved_paths)} images to {output_dir}")
833
+ return saved_paths
834
+
835
+
836
+ class MaskMover:
837
+ def __init__(self):
838
+ self.extracted_fg = None
839
+ self.original_fg = None # Store original foreground
840
+
841
+ def set_extracted_fg(self, fg_image):
842
+ """Store the extracted foreground with alpha channel"""
843
+ if isinstance(fg_image, np.ndarray):
844
+ self.extracted_fg = fg_image.copy()
845
+ self.original_fg = fg_image.copy()
846
+ else:
847
+ self.extracted_fg = np.array(fg_image)
848
+ self.original_fg = np.array(fg_image)
849
+ return self.extracted_fg
850
+
851
+ def create_composite(self, background, x_pos, y_pos, scale=1.0):
852
+ """Create composite with foreground at specified position"""
853
+ if self.original_fg is None or background is None:
854
+ return background
855
+
856
+ # Convert inputs to PIL Images
857
+ if isinstance(background, np.ndarray):
858
+ bg = Image.fromarray(background).convert('RGBA')
859
+ else:
860
+ bg = background.convert('RGBA')
861
+
862
+ if isinstance(self.original_fg, np.ndarray):
863
+ fg = Image.fromarray(self.original_fg).convert('RGBA')
864
+ else:
865
+ fg = self.original_fg.convert('RGBA')
866
+
867
+ # Scale the foreground size
868
+ new_width = int(fg.width * scale)
869
+ new_height = int(fg.height * scale)
870
+ fg = fg.resize((new_width, new_height), Image.LANCZOS)
871
+
872
+ # Center the scaled foreground at the position
873
+ x = int(x_pos - new_width / 2)
874
+ y = int(y_pos - new_height / 2)
875
+
876
+ # Create composite
877
+ result = bg.copy()
878
+ result.paste(fg, (x, y), fg) # Use fg as the mask (requires fg to be in 'RGBA' mode)
879
+
880
+ return np.array(result.convert('RGB')) # Convert back to 'RGB' if needed
881
+
882
+ def get_depth(image):
883
+ if image is None:
884
+ return None
885
+ # Convert from PIL/gradio format to cv2
886
+ raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
887
+ # Get depth map
888
+ depth = model.infer_image(raw_img) # HxW raw depth map
889
+ # Normalize depth for visualization
890
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
891
+ # Convert to RGB for display
892
+ depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
893
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
894
+ return Image.fromarray(depth_colored)
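+ # The raw depth map is min-max normalized and colorized with the INFERNO colormap purely
+ # for display in this function; the absolute depth values are discarded here.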
895
+
896
+
897
+ from PIL import Image
898
+
899
+ def compress_image(image):
900
+ # Convert Gradio image (numpy array) to PIL Image
901
+ img = Image.fromarray(image)
902
+
903
+ # Resize image if dimensions are too large
904
+ max_size = 1024 # Maximum dimension size
905
+ if img.width > max_size or img.height > max_size:
906
+ ratio = min(max_size/img.width, max_size/img.height)
907
+ new_size = (int(img.width * ratio), int(img.height * ratio))
908
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
909
+
910
+ quality = 95 # Start with high quality
911
+ img.save("compressed_image.jpg", "JPEG", quality=quality) # Initial save
912
+
913
+ # Check file size and adjust quality if necessary
914
+ while os.path.getsize("compressed_image.jpg") > 100 * 1024: # 100KB limit
915
+ quality -= 5 # Decrease quality
916
+ img.save("compressed_image.jpg", "JPEG", quality=quality)
917
+ if quality < 20: # Prevent quality from going too low
918
+ break
919
+
920
+ # Convert back to numpy array for Gradio
921
+ compressed_img = np.array(Image.open("compressed_image.jpg"))
922
+ return compressed_img
923
+
924
+ def use_orientation(selected_image:gr.SelectData):
925
+ return selected_image.value['image']['path']
926
+
927
+ @spaces.GPU(duration=60)
928
+ @torch.inference_mode()
929
+ def process_image(input_image, input_text):
930
+ """Main processing function for the Gradio interface"""
931
+
932
+ # Initialize configs
933
+ API_TOKEN = os.getenv("DDS_API_TOKEN", "9c8c865e10ec1821bea79d9fa9dc8720")  # prefer the env var over the hardcoded key
934
+ SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
935
+ SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
936
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
937
+ OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
938
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
939
+
940
+
941
+
942
+ # Initialize DDS client
943
+ config = Config(API_TOKEN)
944
+ client = Client(config)
945
+
946
+ # Process classes from text prompt
947
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
948
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
949
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
950
+
951
+ # Save input image to temp file and get URL
952
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
953
+ cv2.imwrite(tmpfile.name, input_image)
954
+ image_url = client.upload_file(tmpfile.name)
955
+ os.remove(tmpfile.name)
956
+
957
+ # Process detection results
958
+ input_boxes = []
959
+ masks = []
960
+ confidences = []
961
+ class_names = []
962
+ class_ids = []
963
+
964
+ if len(input_text) == 0:
965
+ task = DinoxTask(
966
+ image_url=image_url,
967
+ prompts=[TextPrompt(text="<prompt_free>")],
968
+ # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
969
+ )
970
+
971
+ client.run_task(task)
972
+ predictions = task.result.objects
973
+ classes = [pred.category for pred in predictions]
974
+ classes = list(set(classes))
975
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
976
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
977
+
978
+ for idx, obj in enumerate(predictions):
979
+ input_boxes.append(obj.bbox)
980
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
981
+ confidences.append(obj.score)
982
+ cls_name = obj.category.lower().strip()
983
+ class_names.append(cls_name)
984
+ class_ids.append(class_name_to_id[cls_name])
985
+
986
+ boxes = np.array(input_boxes)
987
+ masks = np.array(masks)
988
+ class_ids = np.array(class_ids)
989
+ labels = [
990
+ f"{class_name} {confidence:.2f}"
991
+ for class_name, confidence
992
+ in zip(class_names, confidences)
993
+ ]
994
+ detections = sv.Detections(
995
+ xyxy=boxes,
996
+ mask=masks.astype(bool),
997
+ class_id=class_ids
998
+ )
999
+
1000
+ box_annotator = sv.BoxAnnotator()
1001
+ label_annotator = sv.LabelAnnotator()
1002
+ mask_annotator = sv.MaskAnnotator()
1003
+
1004
+ annotated_frame = input_image.copy()
1005
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
1006
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
1007
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
1008
+
1009
+ # Create transparent mask for first detected object
1010
+ if len(detections) > 0:
1011
+ # Get first mask
1012
+ first_mask = detections.mask[0]
1013
+
1014
+ # Get original RGB image
1015
+ img = input_image.copy()
1016
+ H, W, C = img.shape
1017
+
1018
+ # Create RGBA image
1019
+ alpha = np.zeros((H, W, 1), dtype=np.uint8)
1020
+ alpha[first_mask] = 255
1021
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
1022
+
1023
+ # Crop to mask bounds to minimize image size
1024
+ y_indices, x_indices = np.where(first_mask)
1025
+ y_min, y_max = y_indices.min(), y_indices.max()
1026
+ x_min, x_max = x_indices.min(), x_indices.max()
1027
+
1028
+ # Crop the RGBA image
1029
+ cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
1030
+
1031
+ # Set extracted foreground for mask mover
1032
+ # mask_mover.set_extracted_fg(cropped_rgba)
1033
+
1034
+ return annotated_frame, cropped_rgba, gr.update(visible=False), gr.update(visible=False)
1035
+
1036
+
1037
+ else:
1038
+ # Run DINO-X detection
1039
+ task = DinoxTask(
1040
+ image_url=image_url,
1041
+ prompts=[TextPrompt(text=input_text)],
1042
+ targets=[DetectionTarget.BBox, DetectionTarget.Mask]
1043
+ )
1044
+
1045
+ client.run_task(task)
1046
+ result = task.result
1047
+ objects = result.objects
1048
+
1049
+
1050
+
1051
+ # for obj in objects:
1052
+ # input_boxes.append(obj.bbox)
1053
+ # confidences.append(obj.score)
1054
+ # cls_name = obj.category.lower().strip()
1055
+ # class_names.append(cls_name)
1056
+ # class_ids.append(class_name_to_id[cls_name])
1057
+
1058
+ # input_boxes = np.array(input_boxes)
1059
+ # class_ids = np.array(class_ids)
1060
+
1061
+ predictions = task.result.objects
1062
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
1063
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
1064
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1065
+
1066
+ boxes = []
1067
+ masks = []
1068
+ confidences = []
1069
+ class_names = []
1070
+ class_ids = []
1071
+
1072
+ for idx, obj in enumerate(predictions):
1073
+ boxes.append(obj.bbox)
1074
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
1075
+ confidences.append(obj.score)
1076
+ cls_name = obj.category.lower().strip()
1077
+ class_names.append(cls_name)
1078
+ class_ids.append(class_name_to_id[cls_name])
1079
+
1080
+ boxes = np.array(boxes)
1081
+ masks = np.array(masks)
1082
+ class_ids = np.array(class_ids)
1083
+ labels = [
1084
+ f"{class_name} {confidence:.2f}"
1085
+ for class_name, confidence
1086
+ in zip(class_names, confidences)
1087
+ ]
1088
+
1089
+ # Initialize SAM2
1090
+ # torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
1091
+ # if torch.cuda.get_device_properties(0).major >= 8:
1092
+ # torch.backends.cuda.matmul.allow_tf32 = True
1093
+ # torch.backends.cudnn.allow_tf32 = True
1094
+
1095
+ # sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
1096
+ # sam2_predictor = SAM2ImagePredictor(sam2_model)
1097
+ # sam2_predictor.set_image(input_image)
1098
+
1099
+ # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
1100
+
1101
+
1102
+ # Get masks from SAM2
1103
+ # masks, scores, logits = sam2_predictor.predict(
1104
+ # point_coords=None,
1105
+ # point_labels=None,
1106
+ # box=input_boxes,
1107
+ # multimask_output=False,
1108
+ # )
1109
+
1110
+ if masks.ndim == 4:
1111
+ masks = masks.squeeze(1)
1112
+
1113
+ # Create visualization
1114
+ # labels = [f"{class_name} {confidence:.2f}"
1115
+ # for class_name, confidence in zip(class_names, confidences)]
1116
+
1117
+ # detections = sv.Detections(
1118
+ # xyxy=input_boxes,
1119
+ # mask=masks.astype(bool),
1120
+ # class_id=class_ids
1121
+ # )
1122
+
1123
+ detections = sv.Detections(
1124
+ xyxy = boxes,
1125
+ mask = masks.astype(bool),
1126
+ class_id = class_ids,
1127
+ )
1128
+
1129
+ box_annotator = sv.BoxAnnotator()
1130
+ label_annotator = sv.LabelAnnotator()
1131
+ mask_annotator = sv.MaskAnnotator()
1132
+
1133
+ annotated_frame = input_image.copy()
1134
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
1135
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
1136
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
1137
+
1138
+ # Create transparent mask for first detected object
1139
+ if len(detections) > 0:
1140
+ # Get first mask
1141
+ first_mask = detections.mask[0]
1142
+
1143
+ # Get original RGB image
1144
+ img = input_image.copy()
1145
+ H, W, C = img.shape
1146
+
1147
+ # Create RGBA image
1148
+ alpha = np.zeros((H, W, 1), dtype=np.uint8)
1149
+ alpha[first_mask] = 255
1150
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
1151
+
1152
+ # Crop to mask bounds to minimize image size
1153
+ y_indices, x_indices = np.where(first_mask)
1154
+ y_min, y_max = y_indices.min(), y_indices.max()
1155
+ x_min, x_max = x_indices.min(), x_indices.max()
1156
+
1157
+ # Crop the RGBA image
1158
+ cropped_rgba = rgba[y_min:y_max+1, x_min:x_max+1]
1159
+
1160
+ # Set extracted foreground for mask mover
1161
+ # mask_mover.set_extracted_fg(cropped_rgba)
1162
+
1163
+ return annotated_frame, cropped_rgba, gr.update(visible=False), gr.update(visible=False)
1164
+ return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
1165
+
1166
+
1167
+ block = gr.Blocks().queue()
1168
+ with block:
1169
+ with gr.Tab("Text"):
1170
+ with gr.Row():
1171
+ gr.Markdown("## Product Placement from Text")
1172
+ with gr.Row():
1173
+ with gr.Column():
1174
+ with gr.Row():
1175
+ input_fg = gr.Image(type="numpy", label="Image", height=480)
1176
+ with gr.Row():
1177
+ with gr.Group():
1178
+ find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
1179
+ text_prompt = gr.Textbox(
1180
+ label="Text Prompt",
1181
+ placeholder="Enter object classes separated by periods (e.g. 'car . person .'), leave empty to get all objects",
1182
+ value=""
1183
+ )
1184
+ extract_button = gr.Button(value="Remove Background")
1185
+ with gr.Row():
1186
+ extracted_objects = gr.Image(type="numpy", label="Segmented Objects", height=480)
1187
+ extracted_fg = gr.Image(type="numpy", label="Extracted Foreground", height=480)
1188
+ angles_fg = gr.Image(type="pil", label="Converted Foreground", height=480, visible=False)
1189
+
1190
+ with gr.Row():
1191
+ run_button = gr.Button("Generate alternative angles")
1192
+
1193
+ orientation_result = gr.Gallery(
1194
+ label="Result",
1195
+ show_label=False,
1196
+ columns=[3],
1197
+ rows=[2],
1198
+ object_fit="contain",
1199
+ height="auto",
1200
+ allow_preview=False,
1201
+ )
1202
+
1203
+ # Clicking one of the generated views loads it back as the extracted foreground
+ orientation_result.select(use_orientation, inputs=None, outputs=extracted_fg)
1206
+
1207
+ # output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
1208
+ with gr.Group():
1209
+ prompt = gr.Textbox(label="Prompt")
1210
+ bg_source = gr.Radio(choices=[e.value for e in list(BGSource)[2:]],
1211
+ value=BGSource.LEFT.value,
1212
+ label="Lighting Preference (Initial Latent)", type='value')
1213
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
1214
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
1215
+ relight_button = gr.Button(value="Relight")
1216
+
1217
+ with gr.Group(visible=False):
1218
+ with gr.Row():
1219
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1220
+ seed = gr.Number(label="Seed", value=12345, precision=0)
1221
+
1222
+ with gr.Row():
1223
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
1224
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
1225
+
1226
+ with gr.Accordion("Advanced options", open=False):
1227
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
1228
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01, visible=False)
1229
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
1230
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
1231
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
1232
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality', visible=False)
1233
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality', visible=False)
1234
+ x_slider = gr.Slider(
1235
+ minimum=0,
1236
+ maximum=1000,
1237
+ label="X Position",
1238
+ value=500,
1239
+ visible=False
1240
+ )
1241
+ y_slider = gr.Slider(
1242
+ minimum=0,
1243
+ maximum=1000,
1244
+ label="Y Position",
1245
+ value=500,
1246
+ visible=False
1247
+ )
1248
+ with gr.Column():
1249
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
1250
+ with gr.Row():
1251
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
1252
+ # gr.Examples(
1253
+ # fn=lambda *args: ([args[-1]], None),
1254
+ # examples=db_examples.foreground_conditioned_examples,
1255
+ # inputs=[
1256
+ # input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
1257
+ # ],
1258
+ # outputs=[result_gallery, output_bg],
1259
+ # run_on_click=True, examples_per_page=1024
1260
+ # )
1261
+ ips = [extracted_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
1262
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[result_gallery])
1263
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
1264
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
1265
+
1266
+
1267
+ def convert_to_pil(image):
1268
+ try:
1269
+ logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
1270
+ image = image.astype(np.uint8)
1271
+ logging.info(f"Converted image shape: {image.shape}, dtype: {image.dtype}")
1272
+ return image
1273
+ except Exception as e:
1274
+ logging.error(f"Error converting image: {e}")
1275
+ return image
1276
+
1277
+ run_button.click(
1278
+ fn=convert_to_pil,
1279
+ inputs=extracted_fg, # This is already RGBA with removed background
1280
+ outputs=angles_fg
1281
+ ).then(
1282
+ fn=infer,
1283
+ inputs=[
1284
+ text_prompt,
1285
+ extracted_fg, # Already processed RGBA image
1286
+ ],
1287
+ outputs=[orientation_result],
1288
+ )
1289
+
1290
+ find_objects_button.click(
1291
+ fn=process_image,
1292
+ inputs=[input_fg, text_prompt],
1293
+ outputs=[extracted_objects, extracted_fg, x_slider, y_slider]  # process_image returns four values (annotated image, cutout, two slider updates)
1294
+ )
1295
+
1296
+ extract_button.click(
1297
+ fn=extract_foreground,
1298
+ inputs=[input_fg],
1299
+ outputs=[extracted_fg, x_slider, y_slider]
1300
+ )
1301
+
1302
+ block.launch(server_name='0.0.0.0', share=False)
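+ # Binds to all interfaces on Gradio's default port (typically 7860); pass server_port=... to launch() to change it.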