Ashoka74 committed
Commit
13fe601
1 Parent(s): 947db12

Create app_merged.py

Files changed (1)
  1. app_merged.py +1608 -0
app_merged.py ADDED
@@ -0,0 +1,1608 @@
1
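+ """Merged Gradio app: text-prompted object extraction (DINO-X + BiRefNet), multi-view
+ generation with MV-Adapter, IC-Light relighting, and a ComfyUI-based FLUX Depth/Redux
+ style-shaping tab."""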
+ import os
+ import sys
+ import math
+ import random
+ import argparse
+ import datetime
+ import tempfile
+ from enum import Enum
+ from io import BytesIO
+ from pathlib import Path
+ from typing import Any, Mapping, Optional, Sequence, Union
+
+ import cv2
+ import httpx
+ import numpy as np
+ import torch
+ import safetensors.torch as sf
+ import gradio as gr
+ import spaces
+ from PIL import Image
+ from torch.hub import download_url_to_file
+ from torchvision import transforms
+ from huggingface_hub import hf_hub_download
+
+ from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline
+ from diffusers import AutoencoderKL, UNet2DConditionModel, DDIMScheduler, EulerAncestralDiscreteScheduler, DPMSolverMultistepScheduler
+ from diffusers.models.attention_processor import AttnProcessor2_0
+ from transformers import CLIPTextModel, CLIPTokenizer, AutoModelForImageSegmentation
+
+ import dds_cloudapi_sdk
+ from dds_cloudapi_sdk import Config, Client, TextPrompt
+ from dds_cloudapi_sdk.tasks.dinox import DinoxTask
+ from dds_cloudapi_sdk.tasks import DetectionTarget
+ from dds_cloudapi_sdk.tasks.detection import DetectionTask
+
+ from sam2.build_sam import build_sam2
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+ from inference_i2mv_sdxl import prepare_pipeline, remove_bg, run_pipeline
+ from depth_anything_v2.dpt import DepthAnythingV2
+
+ client = httpx.Client(timeout=httpx.Timeout(10.0)) # Set timeout to 10 seconds
62
+ NUM_VIEWS = 6
63
+ HEIGHT = 768
64
+ WIDTH = 768
65
+ MAX_SEED = np.iinfo(np.int32).max
66
+
67
+
68
+
69
+ import logging
+ import supervision as sv
74
+
75
+ # Configure logging
76
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(message)s')
77
+
78
+ transform_image = transforms.Compose(
79
+ [
80
+ transforms.Resize((1024, 1024)),
81
+ transforms.ToTensor(),
82
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
83
+ ]
84
+ )
85
+
86
+ # Load
87
+
88
+ # Model paths
89
+ model_path = './models/iclight_sd15_fc.safetensors'
90
+ model_path2 = './checkpoints/depth_anything_v2_vits.pth'
91
+ model_path3 = './checkpoints/sam2_hiera_large.pt'
92
+ model_path4 = './checkpoints/config.json'
93
+ model_path5 = './checkpoints/preprocessor_config.json'
94
+ model_path6 = './configs/sam2_hiera_l.yaml'
95
+ model_path7 = './mvadapter_i2mv_sdxl.safetensors'
96
+
97
+ # Base URL for the repository
98
+ BASE_URL = 'https://huggingface.co/Ashoka74/Placement/resolve/main/'
99
+
100
+ # Model URLs
101
+ model_urls = {
102
+ model_path: 'iclight_sd15_fc.safetensors',
103
+ model_path2: 'depth_anything_v2_vits.pth',
104
+ model_path3: 'sam2_hiera_large.pt',
105
+ model_path4: 'config.json',
106
+ model_path5: 'preprocessor_config.json',
107
+ model_path6: 'sam2_hiera_l.yaml',
108
+ model_path7: 'mvadapter_i2mv_sdxl.safetensors'
109
+ }
110
+
111
+ # Ensure directories exist
112
+ def ensure_directories():
113
+ for path in model_urls.keys():
114
+ os.makedirs(os.path.dirname(path), exist_ok=True)
115
+
116
+ # Download models
117
+ def download_models():
118
+ for local_path, filename in model_urls.items():
119
+ if not os.path.exists(local_path):
120
+ try:
121
+ url = f"{BASE_URL}{filename}"
122
+ print(f"Downloading {filename}")
123
+ download_url_to_file(url, local_path)
124
+ print(f"Successfully downloaded {filename}")
125
+ except Exception as e:
126
+ print(f"Error downloading {filename}: {e}")
127
+
128
+ ensure_directories()
129
+
130
+ download_models()
131
+
132
+
133
+
134
+
135
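+ # Checkpoints for the ComfyUI-based FLUX tab: Redux style model, Depth-dev UNet,
+ # SigCLIP vision encoder, Depth-Anything-V2, FLUX VAE and the CLIP/T5 text encoders.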
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Redux-dev", filename="flux1-redux-dev.safetensors", local_dir="models/style_models")
136
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-Depth-dev", filename="flux1-depth-dev.safetensors", local_dir="models/diffusion_models")
137
+ hf_hub_download(repo_id="Comfy-Org/sigclip_vision_384", filename="sigclip_vision_patch14_384.safetensors", local_dir="models/clip_vision")
138
+ hf_hub_download(repo_id="Kijai/DepthAnythingV2-safetensors", filename="depth_anything_v2_vitl_fp32.safetensors", local_dir="models/depthanything")
139
+ hf_hub_download(repo_id="black-forest-labs/FLUX.1-dev", filename="ae.safetensors", local_dir="models/vae/FLUX1")
140
+ hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="clip_l.safetensors", local_dir="models/text_encoders")
141
+ t5_path = hf_hub_download(repo_id="comfyanonymous/flux_text_encoders", filename="t5xxl_fp16.safetensors", local_dir="models/text_encoders/t5")
142
+
143
+
144
+ sd15_name = 'stablediffusionapi/realistic-vision-v51'
145
+ tokenizer = CLIPTokenizer.from_pretrained(sd15_name, subfolder="tokenizer")
146
+ text_encoder = CLIPTextModel.from_pretrained(sd15_name, subfolder="text_encoder")
147
+ vae = AutoencoderKL.from_pretrained(sd15_name, subfolder="vae")
148
+ unet = UNet2DConditionModel.from_pretrained(sd15_name, subfolder="unet")
149
+
150
+ try:
151
+ import xformers
152
+ import xformers.ops
153
+ XFORMERS_AVAILABLE = True
154
+ print("xformers is available - Using memory efficient attention")
155
+ except ImportError:
156
+ XFORMERS_AVAILABLE = False
157
+ print("xformers not available - Using default attention")
158
+
159
+ # Memory optimizations for RTX 2070
160
+ torch.backends.cudnn.benchmark = True
161
+ if torch.cuda.is_available():
162
+ torch.backends.cuda.matmul.allow_tf32 = True
163
+ torch.backends.cudnn.allow_tf32 = True
164
+ # Cap the CUDA caching-allocator split size to reduce fragmentation on smaller GPUs.
+ # (torch.backends.cuda has no max_split_size_mb attribute; the allocator reads this env var instead.)
+ os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:512")
166
+ device = torch.device('cuda')
167
+ else:
168
+ device = torch.device('cpu')
169
+
170
+
171
+ rmbg = AutoModelForImageSegmentation.from_pretrained(
172
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
173
+ )
174
+ rmbg = rmbg.to(device=device, dtype=torch.float32)
175
+
176
+
177
+ model = DepthAnythingV2(encoder='vits', features=64, out_channels=[48, 96, 192, 384])
178
+ model.load_state_dict(torch.load('checkpoints/depth_anything_v2_vits.pth', map_location=device))
179
+ model = model.to(device)
180
+ model.eval()
181
+
182
+
183
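+ # IC-Light conditioning: widen the UNet's first conv from 4 to 8 input channels so the
+ # encoded foreground latent can be concatenated with the noisy latent; the new channels
+ # start at zero, so behaviour is unchanged until the weight offsets below are merged in.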
+ with torch.no_grad():
184
+ new_conv_in = torch.nn.Conv2d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding)
185
+ new_conv_in.weight.zero_()
186
+ new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight)
187
+ new_conv_in.bias = unet.conv_in.bias
188
+ unet.conv_in = new_conv_in
189
+
190
+
191
+ unet_original_forward = unet.forward
192
+
193
+
194
+ def enable_efficient_attention():
195
+ if XFORMERS_AVAILABLE:
196
+ try:
197
+ # RTX 2070 specific settings
198
+ unet.set_use_memory_efficient_attention_xformers(True)
199
+ vae.set_use_memory_efficient_attention_xformers(True)
200
+ print("Enabled xformers memory efficient attention")
201
+ except Exception as e:
202
+ print(f"Xformers error: {e}")
203
+ print("Falling back to sliced attention")
204
+ # Use sliced attention for RTX 2070
205
+ # unet.set_attention_slice_size(4)
206
+ # vae.set_attention_slice_size(4)
207
+ unet.set_attn_processor(AttnProcessor2_0())
208
+ vae.set_attn_processor(AttnProcessor2_0())
209
+ else:
210
+ # Fallback for when xformers is not available
211
+ print("Using sliced attention")
212
+ # unet.set_attention_slice_size(4)
213
+ # vae.set_attention_slice_size(4)
214
+ unet.set_attn_processor(AttnProcessor2_0())
215
+ vae.set_attn_processor(AttnProcessor2_0())
216
+
217
+ # Add memory clearing function
218
+ def clear_memory():
219
+ if torch.cuda.is_available():
220
+ torch.cuda.empty_cache()
221
+ torch.cuda.synchronize()
222
+
223
+ # Enable efficient attention
224
+ enable_efficient_attention()
225
+
226
+
227
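+ # Wrapped UNet forward: pop the concatenated foreground latents out of
+ # cross_attention_kwargs, tile them to the batch size and stack them onto the input channels.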
+ def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs):
228
+ c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample)
229
+ c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0)
230
+ new_sample = torch.cat([sample, c_concat], dim=1)
231
+ kwargs['cross_attention_kwargs'] = {}
232
+ return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs)
233
+
234
+
235
+ unet.forward = hooked_unet_forward
236
+
237
+
238
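+ # Merge the IC-Light FC weight offsets into the base UNet by element-wise addition.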
+ sd_offset = sf.load_file(model_path)
239
+ sd_origin = unet.state_dict()
240
+ keys = sd_origin.keys()
241
+ sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()}
242
+ unet.load_state_dict(sd_merged, strict=True)
243
+ del sd_offset, sd_origin, sd_merged, keys
244
+
245
+
246
+ # Device and dtype setup
247
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
248
+ #dtype = torch.float16 # RTX 2070 works well with float16
249
+ dtype = torch.bfloat16
250
+
251
+
252
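+ # MV-Adapter (SDXL) pipeline used by the "Generate alternative angles" button to
+ # synthesize NUM_VIEWS views of the extracted object.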
+ pipe = prepare_pipeline(
253
+ base_model="stabilityai/stable-diffusion-xl-base-1.0",
254
+ vae_model="madebyollin/sdxl-vae-fp16-fix",
255
+ unet_model=None,
256
+ lora_model=None,
257
+ adapter_path="huanngzh/mv-adapter",
258
+ scheduler=None,
259
+ num_views=NUM_VIEWS,
260
+ device=device,
261
+ dtype=dtype,
262
+ )
263
+
264
+
265
+ # Move models to device with consistent dtype
266
+ text_encoder = text_encoder.to(device=device, dtype=dtype)
267
+ vae = vae.to(device=device, dtype=dtype)
268
+ unet = unet.to(device=device, dtype=dtype)
269
+ #rmbg = rmbg.to(device=device, dtype=torch.float32) # Keep this as float32
270
+ rmbg = rmbg.to(device)
271
+
272
+ ddim_scheduler = DDIMScheduler(
273
+ num_train_timesteps=1000,
274
+ beta_start=0.00085,
275
+ beta_end=0.012,
276
+ beta_schedule="scaled_linear",
277
+ clip_sample=False,
278
+ set_alpha_to_one=False,
279
+ steps_offset=1,
280
+ )
281
+
282
+ euler_a_scheduler = EulerAncestralDiscreteScheduler(
283
+ num_train_timesteps=1000,
284
+ beta_start=0.00085,
285
+ beta_end=0.012,
286
+ steps_offset=1
287
+ )
288
+
289
+ dpmpp_2m_sde_karras_scheduler = DPMSolverMultistepScheduler(
290
+ num_train_timesteps=1000,
291
+ beta_start=0.00085,
292
+ beta_end=0.012,
293
+ algorithm_type="sde-dpmsolver++",
294
+ use_karras_sigmas=True,
295
+ steps_offset=1
296
+ )
297
+
298
+ # Pipelines
299
+
300
+ t2i_pipe = StableDiffusionPipeline(
301
+ vae=vae,
302
+ text_encoder=text_encoder,
303
+ tokenizer=tokenizer,
304
+ unet=unet,
305
+ scheduler=dpmpp_2m_sde_karras_scheduler,
306
+ safety_checker=None,
307
+ requires_safety_checker=False,
308
+ feature_extractor=None,
309
+ image_encoder=None
310
+ )
311
+
312
+ i2i_pipe = StableDiffusionImg2ImgPipeline(
313
+ vae=vae,
314
+ text_encoder=text_encoder,
315
+ tokenizer=tokenizer,
316
+ unet=unet,
317
+ scheduler=dpmpp_2m_sde_karras_scheduler,
318
+ safety_checker=None,
319
+ requires_safety_checker=False,
320
+ feature_extractor=None,
321
+ image_encoder=None
322
+ )
323
+
324
+
325
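+ # Encode prompts longer than CLIP's 77-token window by splitting them into 75-token
+ # chunks (plus BOS/EOS) and stacking the resulting embeddings.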
+ @torch.inference_mode()
326
+ def encode_prompt_inner(txt: str):
327
+ max_length = tokenizer.model_max_length
328
+ chunk_length = tokenizer.model_max_length - 2
329
+ id_start = tokenizer.bos_token_id
330
+ id_end = tokenizer.eos_token_id
331
+ id_pad = id_end
332
+
333
+ def pad(x, p, i):
334
+ return x[:i] if len(x) >= i else x + [p] * (i - len(x))
335
+
336
+ tokens = tokenizer(txt, truncation=False, add_special_tokens=False)["input_ids"]
337
+ chunks = [[id_start] + tokens[i: i + chunk_length] + [id_end] for i in range(0, len(tokens), chunk_length)]
338
+ chunks = [pad(ck, id_pad, max_length) for ck in chunks]
339
+
340
+ token_ids = torch.tensor(chunks).to(device=device, dtype=torch.int64)
341
+ conds = text_encoder(token_ids).last_hidden_state
342
+
343
+ return conds
344
+
345
+
346
+ @torch.inference_mode()
347
+ def encode_prompt_pair(positive_prompt, negative_prompt):
348
+ c = encode_prompt_inner(positive_prompt)
349
+ uc = encode_prompt_inner(negative_prompt)
350
+
351
+ c_len = float(len(c))
352
+ uc_len = float(len(uc))
353
+ max_count = max(c_len, uc_len)
354
+ c_repeat = int(math.ceil(max_count / c_len))
355
+ uc_repeat = int(math.ceil(max_count / uc_len))
356
+ max_chunk = max(len(c), len(uc))
357
+
358
+ c = torch.cat([c] * c_repeat, dim=0)[:max_chunk]
359
+ uc = torch.cat([uc] * uc_repeat, dim=0)[:max_chunk]
360
+
361
+ c = torch.cat([p[None, ...] for p in c], dim=1)
362
+ uc = torch.cat([p[None, ...] for p in uc], dim=1)
363
+
364
+ return c, uc
365
+
366
+ # @spaces.GPU(duration=60)
367
+ # @torch.inference_mode()
368
+ @spaces.GPU(duration=60)
369
+ @torch.inference_mode()
370
+ def infer(
371
+ prompt,
372
+ image, # This is already RGBA with background removed
373
+ do_rembg=True,
374
+ seed=42,
375
+ randomize_seed=False,
376
+ guidance_scale=3.0,
377
+ num_inference_steps=50,
378
+ reference_conditioning_scale=1.0,
379
+ negative_prompt="watermark, ugly, deformed, noisy, blurry, low contrast",
380
+ progress=gr.Progress(track_tqdm=True),
381
+ ):
382
+ #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
383
+
384
+ # Convert input to PIL if needed
385
+ if isinstance(image, np.ndarray):
386
+ if image.shape[-1] == 4: # RGBA
387
+ image = Image.fromarray(image, 'RGBA')
388
+ else: # RGB
389
+ image = Image.fromarray(image, 'RGB')
390
+
391
+ #logging.info(f"Converted to PIL Image mode: {image.mode}")
392
+
393
+ # No need for remove_bg_fn since image is already processed
394
+ remove_bg_fn = None
395
+
396
+ if randomize_seed:
397
+ seed = random.randint(0, MAX_SEED)
398
+
399
+ images, preprocessed_image = run_pipeline(
400
+ pipe,
401
+ num_views=NUM_VIEWS,
402
+ text=prompt,
403
+ image=image,
404
+ height=HEIGHT,
405
+ width=WIDTH,
406
+ num_inference_steps=num_inference_steps,
407
+ guidance_scale=guidance_scale,
408
+ seed=seed,
409
+ remove_bg_fn=remove_bg_fn, # Set to None since preprocessing is done
410
+ reference_conditioning_scale=reference_conditioning_scale,
411
+ negative_prompt=negative_prompt,
412
+ device=device,
413
+ )
414
+
415
+ # logging.info(f"Output images shape: {[img.shape for img in images]}")
416
+ # logging.info(f"Preprocessed image shape: {preprocessed_image.shape if preprocessed_image is not None else None}")
417
+ return images
418
+
419
+
420
+ @spaces.GPU(duration=60)
421
+ @torch.inference_mode()
422
+ def pytorch2numpy(imgs, quant=True):
423
+ results = []
424
+ for x in imgs:
425
+ y = x.movedim(0, -1)
426
+
427
+ if quant:
428
+ y = y * 127.5 + 127.5
429
+ y = y.detach().float().cpu().numpy().clip(0, 255).astype(np.uint8)
430
+ else:
431
+ y = y * 0.5 + 0.5
432
+ y = y.detach().float().cpu().numpy().clip(0, 1).astype(np.float32)
433
+
434
+ results.append(y)
435
+ return results
436
+
437
+ @spaces.GPU(duration=60)
438
+ @torch.inference_mode()
439
+ def numpy2pytorch(imgs):
440
+ h = torch.from_numpy(np.stack(imgs, axis=0)).float() / 127.0 - 1.0 # so that a pixel value of 127 maps exactly to 0.0
441
+ h = h.movedim(-1, 1)
442
+ return h
443
+
444
+
445
+ def resize_and_center_crop(image, target_width, target_height):
446
+ pil_image = Image.fromarray(image)
447
+ original_width, original_height = pil_image.size
448
+ scale_factor = max(target_width / original_width, target_height / original_height)
449
+ resized_width = int(round(original_width * scale_factor))
450
+ resized_height = int(round(original_height * scale_factor))
451
+ resized_image = pil_image.resize((resized_width, resized_height), Image.LANCZOS)
452
+ left = (resized_width - target_width) / 2
453
+ top = (resized_height - target_height) / 2
454
+ right = (resized_width + target_width) / 2
455
+ bottom = (resized_height + target_height) / 2
456
+ cropped_image = resized_image.crop((left, top, right, bottom))
457
+ return np.array(cropped_image)
458
+
459
+
460
+ def resize_without_crop(image, target_width, target_height):
461
+ pil_image = Image.fromarray(image)
462
+ resized_image = pil_image.resize((target_width, target_height), Image.LANCZOS)
463
+ return np.array(resized_image)
464
+
465
+ # @spaces.GPU(duration=60)
466
+ # @torch.inference_mode()
467
+ # def run_rmbg(img, sigma=0.0):
468
+ # # Convert RGBA to RGB if needed
469
+ # if img.shape[-1] == 4:
470
+ # # Use white background for alpha composition
471
+ # alpha = img[..., 3:] / 255.0
472
+ # rgb = img[..., :3]
473
+ # white_bg = np.ones_like(rgb) * 255
474
+ # img = (rgb * alpha + white_bg * (1 - alpha)).astype(np.uint8)
475
+
476
+ # H, W, C = img.shape
477
+ # assert C == 3
478
+ # k = (256.0 / float(H * W)) ** 0.5
479
+ # feed = resize_without_crop(img, int(64 * round(W * k)), int(64 * round(H * k)))
480
+ # feed = numpy2pytorch([feed]).to(device=device, dtype=torch.float32)
481
+ # alpha = rmbg(feed)[0][0]
482
+ # alpha = torch.nn.functional.interpolate(alpha, size=(H, W), mode="bilinear")
483
+ # alpha = alpha.movedim(1, -1)[0]
484
+ # alpha = alpha.detach().float().cpu().numpy().clip(0, 1)
485
+
486
+ # # Create RGBA image
487
+ # rgba = np.dstack((img, alpha * 255)).astype(np.uint8)
488
+ # result = 127 + (img.astype(np.float32) - 127 + sigma) * alpha
489
+ # return result.clip(0, 255).astype(np.uint8), rgba
490
+
491
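+ # BiRefNet background removal: predict a soft alpha matte, resize it to the input size
+ # and attach it to the image as an alpha channel.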
+ @spaces.GPU
492
+ @torch.inference_mode()
493
+ def run_rmbg(image):
494
+ image_size = image.size
495
+ input_images = transform_image(image).unsqueeze(0).to("cuda")
496
+ # Prediction
497
+ with torch.no_grad():
498
+ preds = rmbg(input_images)[-1].sigmoid().cpu()
499
+ pred = preds[0].squeeze()
500
+ pred_pil = transforms.ToPILImage()(pred)
501
+ mask = pred_pil.resize(image_size)
502
+ image.putalpha(mask)
503
+ return image
504
+
505
+
506
+
507
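+ # Crop the RGBA input to its alpha bounding box, resize the longer side to 90% of the
+ # target size, pad back to (height, width) and composite over mid-grey.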
+ def preprocess_image(image: Image.Image, height=768, width=768):
508
+ image = np.array(image)
509
+ alpha = image[..., 3] > 0
510
+ H, W = alpha.shape
511
+ # get the bounding box of alpha
512
+ y, x = np.where(alpha)
513
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
514
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
515
+ image_center = image[y0:y1, x0:x1]
516
+ # resize the longer side to H * 0.9
517
+ H, W, _ = image_center.shape
518
+ if H > W:
519
+ W = int(W * (height * 0.9) / H)
520
+ H = int(height * 0.9)
521
+ else:
522
+ H = int(H * (width * 0.9) / W)
523
+ W = int(width * 0.9)
524
+ image_center = np.array(Image.fromarray(image_center).resize((W, H)))
525
+ # pad to H, W
526
+ start_h = (height - H) // 2
527
+ start_w = (width - W) // 2
528
+ image = np.zeros((height, width, 4), dtype=np.uint8)
529
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
530
+ image = image.astype(np.float32) / 255.0
531
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
532
+ image = (image * 255).clip(0, 255).astype(np.uint8)
533
+ image = Image.fromarray(image)
534
+ return image
535
+
536
+
537
+ @spaces.GPU(duration=60)
538
+ @torch.inference_mode()
539
+ def process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
540
+ clear_memory()
541
+
542
+ # Get input dimensions
543
+ input_height, input_width = input_fg.shape[:2]
544
+
545
+ bg_source = BGSource(bg_source)
546
+
547
+
548
+ input_bg = None # no background upload is wired into this tab; None selects the text-to-image branch below
+ if bg_source in (BGSource.UPLOAD, BGSource.NONE):
+ pass
+ elif bg_source == BGSource.UPLOAD_FLIP:
+ input_bg = np.fliplr(input_bg) if input_bg is not None else None
+ elif bg_source == BGSource.GREY:
553
+ input_bg = np.zeros(shape=(input_height, input_width, 3), dtype=np.uint8) + 64
554
+ elif bg_source == BGSource.LEFT:
555
+ gradient = np.linspace(255, 0, input_width)
556
+ image = np.tile(gradient, (input_height, 1))
557
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
558
+ elif bg_source == BGSource.RIGHT:
559
+ gradient = np.linspace(0, 255, input_width)
560
+ image = np.tile(gradient, (input_height, 1))
561
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
562
+ elif bg_source == BGSource.TOP:
563
+ gradient = np.linspace(255, 0, input_height)[:, None]
564
+ image = np.tile(gradient, (1, input_width))
565
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
566
+ elif bg_source == BGSource.BOTTOM:
567
+ gradient = np.linspace(0, 255, input_height)[:, None]
568
+ image = np.tile(gradient, (1, input_width))
569
+ input_bg = np.stack((image,) * 3, axis=-1).astype(np.uint8)
570
+ else:
571
+ raise ValueError('Wrong initial latent!')
572
+
573
+ rng = torch.Generator(device=device).manual_seed(int(seed))
574
+
575
+ # Use input dimensions directly
576
+ fg = resize_without_crop(input_fg, input_width, input_height)
577
+
578
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
579
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
580
+
581
+ conds, unconds = encode_prompt_pair(positive_prompt=prompt + ', ' + a_prompt, negative_prompt=n_prompt)
582
+
583
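+ # No background latent: sample from scratch (text-to-image). Otherwise, encode the
+ # background and partially denoise it (img2img) so it steers the initial latent.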
+ if input_bg is None:
584
+ latents = t2i_pipe(
585
+ prompt_embeds=conds,
586
+ negative_prompt_embeds=unconds,
587
+ width=input_width,
588
+ height=input_height,
589
+ num_inference_steps=steps,
590
+ num_images_per_prompt=num_samples,
591
+ generator=rng,
592
+ output_type='latent',
593
+ guidance_scale=cfg,
594
+ cross_attention_kwargs={'concat_conds': concat_conds},
595
+ ).images.to(vae.dtype) / vae.config.scaling_factor
596
+ else:
597
+ bg = resize_without_crop(input_bg, input_width, input_height)
598
+ bg_latent = numpy2pytorch([bg]).to(device=vae.device, dtype=vae.dtype)
599
+ bg_latent = vae.encode(bg_latent).latent_dist.mode() * vae.config.scaling_factor
600
+ latents = i2i_pipe(
601
+ image=bg_latent,
602
+ strength=lowres_denoise,
603
+ prompt_embeds=conds,
604
+ negative_prompt_embeds=unconds,
605
+ width=input_width,
606
+ height=input_height,
607
+ num_inference_steps=int(round(steps / lowres_denoise)),
608
+ num_images_per_prompt=num_samples,
609
+ generator=rng,
610
+ output_type='latent',
611
+ guidance_scale=cfg,
612
+ cross_attention_kwargs={'concat_conds': concat_conds},
613
+ ).images.to(vae.dtype) / vae.config.scaling_factor
614
+
615
+ pixels = vae.decode(latents).sample
616
+ pixels = pytorch2numpy(pixels)
617
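+ # Second pass ("highres fix"): upscale the low-res result, re-encode it, and run
+ # img2img again at the higher resolution with the same foreground conditioning.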
+ pixels = [resize_without_crop(
618
+ image=p,
619
+ target_width=int(round(input_width * highres_scale / 64.0) * 64),
620
+ target_height=int(round(input_height * highres_scale / 64.0) * 64))
621
+ for p in pixels]
622
+
623
+ pixels = numpy2pytorch(pixels).to(device=vae.device, dtype=vae.dtype)
624
+ latents = vae.encode(pixels).latent_dist.mode() * vae.config.scaling_factor
625
+ latents = latents.to(device=unet.device, dtype=unet.dtype)
626
+
627
+ highres_height, highres_width = latents.shape[2] * 8, latents.shape[3] * 8
628
+
629
+ fg = resize_without_crop(input_fg, highres_width, highres_height)
630
+ concat_conds = numpy2pytorch([fg]).to(device=vae.device, dtype=vae.dtype)
631
+ concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor
632
+
633
+ latents = i2i_pipe(
634
+ image=latents,
635
+ strength=highres_denoise,
636
+ prompt_embeds=conds,
637
+ negative_prompt_embeds=unconds,
638
+ width=highres_width,
639
+ height=highres_height,
640
+ num_inference_steps=int(round(steps / highres_denoise)),
641
+ num_images_per_prompt=num_samples,
642
+ generator=rng,
643
+ output_type='latent',
644
+ guidance_scale=cfg,
645
+ cross_attention_kwargs={'concat_conds': concat_conds},
646
+ ).images.to(vae.dtype) / vae.config.scaling_factor
647
+
648
+ pixels = vae.decode(latents).sample
649
+ pixels = pytorch2numpy(pixels)
650
+
651
+ # Resize back to input dimensions
652
+ pixels = [resize_without_crop(p, input_width, input_height) for p in pixels]
653
+ pixels = np.stack(pixels)
654
+
655
+ return pixels
656
+
657
+ def extract_foreground(image):
658
+ if image is None:
659
+ return None, gr.update(visible=True), gr.update(visible=True)
660
+ #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
661
+ #result, rgba = run_rmbg(image)
662
+ result = run_rmbg(image)
663
+ result = preprocess_image(result)
664
+ #logging.info(f"Result shape: {result.shape}, dtype: {result.dtype}")
665
+ #logging.info(f"RGBA shape: {rgba.shape}, dtype: {rgba.dtype}")
666
+ return result, gr.update(visible=True), gr.update(visible=True)
667
+
668
+ def update_extracted_fg_height(selected_image: gr.SelectData):
669
+ if selected_image:
670
+ # Get the height of the selected image
671
+ height = selected_image.value['image']['shape'][0] # Assuming the image is in numpy format
672
+ return gr.update(height=height) # Update the height of extracted_fg
673
+ return gr.update(height=480) # Default height if no image is selected
674
+
675
+
676
+
677
+ @torch.inference_mode()
678
+ def process_relight(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source):
679
+ # Convert input foreground from PIL to NumPy array if it's in PIL format
680
+ if isinstance(input_fg, Image.Image):
681
+ input_fg = np.array(input_fg)
682
+ logging.info(f"Input foreground shape: {input_fg.shape}, dtype: {input_fg.dtype}")
683
+ results = process(input_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source)
684
+ logging.info(f"Results shape: {results.shape}, dtype: {results.dtype}")
685
+ return results
686
+
687
+
688
+ quick_prompts = [
689
+ 'sunshine from window',
690
+ 'golden time',
691
+ 'natural lighting',
692
+ 'warm atmosphere, at home, bedroom',
693
+ 'shadow from window',
694
+ 'soft studio lighting',
695
+ 'home atmosphere, cozy bedroom illumination',
696
+ ]
697
+ quick_prompts = [[x] for x in quick_prompts]
698
+
699
+
700
+ quick_subjects = [
701
+ 'modern sofa, high quality leather',
702
+ 'elegant dining table, polished wood',
703
+ 'luxurious bed, premium mattress',
704
+ 'minimalist office desk, clean design',
705
+ 'vintage wooden cabinet, antique finish',
706
+ ]
707
+ quick_subjects = [[x] for x in quick_subjects]
708
+
709
+
710
+ class BGSource(Enum):
711
+ UPLOAD = "Use Background Image"
712
+ UPLOAD_FLIP = "Use Flipped Background Image"
713
+ NONE = "None"
714
+ LEFT = "Left Light"
715
+ RIGHT = "Right Light"
716
+ TOP = "Top Light"
717
+ BOTTOM = "Bottom Light"
718
+ GREY = "Ambient"
719
+
720
+ # Add save function
721
+ def save_images(images, prefix="relight"):
722
+ # Create output directory if it doesn't exist
723
+ output_dir = Path("outputs")
724
+ output_dir.mkdir(exist_ok=True)
725
+
726
+ # Create timestamp for unique filenames
727
+ timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
728
+
729
+ saved_paths = []
730
+ for i, img in enumerate(images):
731
+ if isinstance(img, np.ndarray):
732
+ # Convert to PIL Image if numpy array
733
+ img = Image.fromarray(img)
734
+
735
+ # Create filename with timestamp
736
+ filename = f"{prefix}_{timestamp}_{i+1}.png"
737
+ filepath = output_dir / filename
738
+
739
+ # Save image
740
+ img.save(filepath)
+ saved_paths.append(str(filepath))
741
+
742
+
743
+ # print(f"Saved {len(saved_paths)} images to {output_dir}")
744
+ return saved_paths
745
+
746
+
747
+ class MaskMover:
748
+ def __init__(self):
749
+ self.extracted_fg = None
750
+ self.original_fg = None # Store original foreground
751
+
752
+ def set_extracted_fg(self, fg_image):
753
+ """Store the extracted foreground with alpha channel"""
754
+ if isinstance(fg_image, np.ndarray):
755
+ self.extracted_fg = fg_image.copy()
756
+ self.original_fg = fg_image.copy()
757
+ else:
758
+ self.extracted_fg = np.array(fg_image)
759
+ self.original_fg = np.array(fg_image)
760
+ return self.extracted_fg
761
+
762
+ def create_composite(self, background, x_pos, y_pos, scale=1.0):
763
+ """Create composite with foreground at specified position"""
764
+ if self.original_fg is None or background is None:
765
+ return background
766
+
767
+ # Convert inputs to PIL Images
768
+ if isinstance(background, np.ndarray):
769
+ bg = Image.fromarray(background).convert('RGBA')
770
+ else:
771
+ bg = background.convert('RGBA')
772
+
773
+ if isinstance(self.original_fg, np.ndarray):
774
+ fg = Image.fromarray(self.original_fg).convert('RGBA')
775
+ else:
776
+ fg = self.original_fg.convert('RGBA')
777
+
778
+ # Scale the foreground size
779
+ new_width = int(fg.width * scale)
780
+ new_height = int(fg.height * scale)
781
+ fg = fg.resize((new_width, new_height), Image.LANCZOS)
782
+
783
+ # Center the scaled foreground at the position
784
+ x = int(x_pos - new_width / 2)
785
+ y = int(y_pos - new_height / 2)
786
+
787
+ # Create composite
788
+ result = bg.copy()
789
+ result.paste(fg, (x, y), fg) # Use fg as the mask (requires fg to be in 'RGBA' mode)
790
+
791
+ return np.array(result.convert('RGB')) # Convert back to 'RGB' if needed
792
+
793
+ def get_depth(image):
794
+ if image is None:
795
+ return None
796
+ # Convert from PIL/gradio format to cv2
797
+ raw_img = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
798
+ # Get depth map
799
+ depth = model.infer_image(raw_img) # HxW raw depth map
800
+ # Normalize depth for visualization
801
+ depth = ((depth - depth.min()) / (depth.max() - depth.min()) * 255).astype(np.uint8)
802
+ # Convert to RGB for display
803
+ depth_colored = cv2.applyColorMap(depth, cv2.COLORMAP_INFERNO)
804
+ depth_colored = cv2.cvtColor(depth_colored, cv2.COLOR_BGR2RGB)
805
+ return Image.fromarray(depth_colored)
806
+
807
+
809
+
810
+ def compress_image(image):
811
+ # Convert Gradio image (numpy array) to PIL Image
812
+ img = Image.fromarray(image)
813
+
814
+ # Resize image if dimensions are too large
815
+ max_size = 1024 # Maximum dimension size
816
+ if img.width > max_size or img.height > max_size:
817
+ ratio = min(max_size/img.width, max_size/img.height)
818
+ new_size = (int(img.width * ratio), int(img.height * ratio))
819
+ img = img.resize(new_size, Image.Resampling.LANCZOS)
820
+
821
+ quality = 95 # Start with high quality
822
+ img.save("compressed_image.jpg", "JPEG", quality=quality) # Initial save
823
+
824
+ # Check file size and adjust quality if necessary
825
+ while os.path.getsize("compressed_image.jpg") > 100 * 1024: # 100KB limit
826
+ quality -= 5 # Decrease quality
827
+ img.save("compressed_image.jpg", "JPEG", quality=quality)
828
+ if quality < 20: # Prevent quality from going too low
829
+ break
830
+
831
+ # Convert back to numpy array for Gradio
832
+ compressed_img = np.array(Image.open("compressed_image.jpg"))
833
+ return compressed_img
834
+
835
+ def use_orientation(selected_image:gr.SelectData):
836
+ return selected_image.value['image']['path']
837
+
838
+
839
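+ # Text-prompted segmentation via the DDS cloud API (DINO-X): detect objects matching the
+ # prompt (or all objects if the prompt is empty), annotate the frame, and return a
+ # recentered cutout of the first detected mask composited over mid-grey.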
+ @spaces.GPU(duration=60)
840
+ @torch.inference_mode()
841
+ def process_image(input_image, input_text):
842
+ """Main processing function for the Gradio interface"""
843
+
844
+
845
+
846
+ if isinstance(input_image, Image.Image):
847
+ input_image = np.array(input_image)
848
+
849
+ # Initialize configs
850
+ # Prefer reading the DDS API token from the environment ("DDS_API_TOKEN" is an arbitrary name here); the literal is kept as a fallback
+ API_TOKEN = os.getenv("DDS_API_TOKEN", "9c8c865e10ec1821bea79d9fa9dc8720")
851
+ SAM2_CHECKPOINT = "./checkpoints/sam2_hiera_large.pt"
852
+ SAM2_MODEL_CONFIG = os.path.join(os.path.dirname(os.path.abspath(__file__)), "configs/sam2_hiera_l.yaml")
853
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
854
+ OUTPUT_DIR = Path("outputs/grounded_sam2_dinox_demo")
855
+ OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
856
+
857
+ HEIGHT = 768
858
+ WIDTH = 768
859
+
860
+
861
+ # Initialize DDS client
862
+ config = Config(API_TOKEN)
863
+ client = Client(config)
864
+
865
+ # Process classes from text prompt
866
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
867
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
868
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
869
+
870
+
871
+
872
+ # Save input image to temp file and get URL
873
+ with tempfile.NamedTemporaryFile(suffix='.jpg', delete=False) as tmpfile:
874
+ cv2.imwrite(tmpfile.name, cv2.cvtColor(input_image, cv2.COLOR_RGB2BGR)) # OpenCV expects BGR; Gradio provides RGB
875
+ image_url = client.upload_file(tmpfile.name)
876
+ os.remove(tmpfile.name)
877
+
878
+ # Process detection results
879
+ input_boxes = []
880
+ masks = []
881
+ confidences = []
882
+ class_names = []
883
+ class_ids = []
884
+
885
+ if len(input_text) == 0:
886
+ task = DinoxTask(
887
+ image_url=image_url,
888
+ prompts=[TextPrompt(text="<prompt_free>")],
889
+ # targets=[DetectionTarget.BBox, DetectionTarget.Mask]
890
+ )
891
+
892
+ client.run_task(task)
893
+ predictions = task.result.objects
894
+ classes = [pred.category for pred in predictions]
895
+ classes = list(set(classes))
896
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
897
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
898
+
899
+ for idx, obj in enumerate(predictions):
900
+ input_boxes.append(obj.bbox)
901
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
902
+ confidences.append(obj.score)
903
+ cls_name = obj.category.lower().strip()
904
+ class_names.append(cls_name)
905
+ class_ids.append(class_name_to_id[cls_name])
906
+
907
+ boxes = np.array(input_boxes)
908
+ masks = np.array(masks)
909
+ class_ids = np.array(class_ids)
910
+ labels = [
911
+ f"{class_name} {confidence:.2f}"
912
+ for class_name, confidence
913
+ in zip(class_names, confidences)
914
+ ]
915
+ detections = sv.Detections(
916
+ xyxy=boxes,
917
+ mask=masks.astype(bool),
918
+ class_id=class_ids
919
+ )
920
+
921
+ box_annotator = sv.BoxAnnotator()
922
+ label_annotator = sv.LabelAnnotator()
923
+ mask_annotator = sv.MaskAnnotator()
924
+
925
+ annotated_frame = input_image.copy()
926
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
927
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
928
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
929
+
930
+ # Create transparent mask for first detected object
931
+ if len(detections) > 0:
932
+ # Get first mask
933
+ first_mask = detections.mask[0]
934
+
935
+ # Get original RGB image
936
+ img = input_image.copy()
937
+
938
+ H, W, C = img.shape
939
+
940
+ # Build an RGBA cutout: alpha is 255 inside the first detected mask, 0 elsewhere
+ alpha = np.zeros((H, W), dtype=np.uint8)
+ alpha[first_mask] = 255
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+ # get the bounding box of the mask
+ y, x = np.where(alpha > 0)
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
+
+ image_center = rgba[y0:y1, x0:x1]
966
+ # resize the longer side to H * 0.9
967
+ H, W, _ = image_center.shape
968
+ if H > W:
969
+ W = int(W * (HEIGHT * 0.9) / H)
970
+ H = int(HEIGHT * 0.9)
971
+ else:
972
+ H = int(H * (WIDTH * 0.9) / W)
973
+ W = int(WIDTH * 0.9)
974
+
975
+ image_center = np.array(Image.fromarray(image_center).resize((W, H)))
976
+ # pad to H, W
977
+ start_h = (HEIGHT - H) // 2
978
+ start_w = (WIDTH - W) // 2
979
+ image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
980
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
981
+ image = image.astype(np.float32) / 255.0
982
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
983
+ image = (image * 255).clip(0, 255).astype(np.uint8)
984
+ image = Image.fromarray(image)
985
+
986
+ return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
987
+
988
+
989
+ else:
990
+ # Run DINO-X detection
991
+ task = DinoxTask(
992
+ image_url=image_url,
993
+ prompts=[TextPrompt(text=input_text)],
994
+ targets=[DetectionTarget.BBox, DetectionTarget.Mask]
995
+ )
996
+
997
+ client.run_task(task)
998
+ result = task.result
999
+ objects = result.objects
1000
+
1001
+
1002
+
1003
+ # for obj in objects:
1004
+ # input_boxes.append(obj.bbox)
1005
+ # confidences.append(obj.score)
1006
+ # cls_name = obj.category.lower().strip()
1007
+ # class_names.append(cls_name)
1008
+ # class_ids.append(class_name_to_id[cls_name])
1009
+
1010
+ # input_boxes = np.array(input_boxes)
1011
+ # class_ids = np.array(class_ids)
1012
+
1013
+ predictions = task.result.objects
1014
+ classes = [x.strip().lower() for x in input_text.split('.') if x]
1015
+ class_name_to_id = {name: id for id, name in enumerate(classes)}
1016
+ class_id_to_name = {id: name for name, id in class_name_to_id.items()}
1017
+
1018
+ boxes = []
1019
+ masks = []
1020
+ confidences = []
1021
+ class_names = []
1022
+ class_ids = []
1023
+
1024
+ for idx, obj in enumerate(predictions):
1025
+ boxes.append(obj.bbox)
1026
+ masks.append(DetectionTask.rle2mask(DetectionTask.string2rle(obj.mask.counts), obj.mask.size)) # convert mask to np.array using DDS API
1027
+ confidences.append(obj.score)
1028
+ cls_name = obj.category.lower().strip()
1029
+ class_names.append(cls_name)
1030
+ class_ids.append(class_name_to_id[cls_name])
1031
+
1032
+ boxes = np.array(boxes)
1033
+ masks = np.array(masks)
1034
+ class_ids = np.array(class_ids)
1035
+ labels = [
1036
+ f"{class_name} {confidence:.2f}"
1037
+ for class_name, confidence
1038
+ in zip(class_names, confidences)
1039
+ ]
1040
+
1041
+ # Initialize SAM2
1042
+ # torch.autocast(device_type=DEVICE, dtype=torch.bfloat16).__enter__()
1043
+ # if torch.cuda.get_device_properties(0).major >= 8:
1044
+ # torch.backends.cuda.matmul.allow_tf32 = True
1045
+ # torch.backends.cudnn.allow_tf32 = True
1046
+
1047
+ # sam2_model = build_sam2(SAM2_MODEL_CONFIG, SAM2_CHECKPOINT, device=DEVICE)
1048
+ # sam2_predictor = SAM2ImagePredictor(sam2_model)
1049
+ # sam2_predictor.set_image(input_image)
1050
+
1051
+ # sam2_predictor = run_sam_inference(SAM_IMAGE_MODEL, input_image, detections)
1052
+
1053
+
1054
+ # Get masks from SAM2
1055
+ # masks, scores, logits = sam2_predictor.predict(
1056
+ # point_coords=None,
1057
+ # point_labels=None,
1058
+ # box=input_boxes,
1059
+ # multimask_output=False,
1060
+ # )
1061
+
1062
+ if masks.ndim == 4:
1063
+ masks = masks.squeeze(1)
1064
+
1065
+ # Create visualization
1066
+ # labels = [f"{class_name} {confidence:.2f}"
1067
+ # for class_name, confidence in zip(class_names, confidences)]
1068
+
1069
+ # detections = sv.Detections(
1070
+ # xyxy=input_boxes,
1071
+ # mask=masks.astype(bool),
1072
+ # class_id=class_ids
1073
+ # )
1074
+
1075
+ detections = sv.Detections(
1076
+ xyxy = boxes,
1077
+ mask = masks.astype(bool),
1078
+ class_id = class_ids,
1079
+ )
1080
+
1081
+ box_annotator = sv.BoxAnnotator()
1082
+ label_annotator = sv.LabelAnnotator()
1083
+ mask_annotator = sv.MaskAnnotator()
1084
+
1085
+ annotated_frame = input_image.copy()
1086
+ annotated_frame = box_annotator.annotate(scene=annotated_frame, detections=detections)
1087
+ annotated_frame = label_annotator.annotate(scene=annotated_frame, detections=detections, labels=labels)
1088
+ annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
1089
+
1090
+ # Create transparent mask for first detected object
1091
+ if len(detections) > 0:
1092
+ # Get first mask
1093
+ first_mask = detections.mask[0]
1094
+
1095
+ # Get original RGB image
1096
+ img = input_image.copy()
1097
+ H, W, C = img.shape
1098
+
1099
+ # Build an RGBA cutout: alpha is 255 inside the first detected mask, 0 elsewhere
+ alpha = np.zeros((H, W), dtype=np.uint8)
+ alpha[first_mask] = 255
+ rgba = np.dstack((img, alpha)).astype(np.uint8)
+
+ # get the bounding box of the mask
+ y, x = np.where(alpha > 0)
+ y0, y1 = max(y.min() - 1, 0), min(y.max() + 1, H)
+ x0, x1 = max(x.min() - 1, 0), min(x.max() + 1, W)
+
+ image_center = rgba[y0:y1, x0:x1]
1129
+ # resize the longer side to H * 0.9
1130
+ H, W, _ = image_center.shape
1131
+ if H > W:
1132
+ W = int(W * (HEIGHT * 0.9) / H)
1133
+ H = int(HEIGHT * 0.9)
1134
+ else:
1135
+ H = int(H * (WIDTH * 0.9) / W)
1136
+ W = int(WIDTH * 0.9)
1137
+
1138
+ image_center = np.array(Image.fromarray(image_center).resize((W, H)))
1139
+ # pad to H, W
1140
+ start_h = (HEIGHT - H) // 2
1141
+ start_w = (WIDTH - W) // 2
1142
+ image = np.zeros((HEIGHT, WIDTH, 4), dtype=np.uint8)
1143
+ image[start_h : start_h + H, start_w : start_w + W] = image_center
1144
+ image = image.astype(np.float32) / 255.0
1145
+ image = image[:, :, :3] * image[:, :, 3:4] + (1 - image[:, :, 3:4]) * 0.5
1146
+ image = (image * 255).clip(0, 255).astype(np.uint8)
1147
+ image = Image.fromarray(image)
1148
+
1149
+ return annotated_frame, image, gr.update(visible=False), gr.update(visible=False)
1150
+ return annotated_frame, None, gr.update(visible=False), gr.update(visible=False)
1151
+
1152
+
1153
+
1154
+
1155
+ # Import all the necessary functions from the original script
1156
+ def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
1157
+ try:
1158
+ return obj[index]
1159
+ except KeyError:
1160
+ return obj["result"][index]
1161
+
1162
+ # Add all the necessary setup functions from the original script
1163
+ def find_path(name: str, path: str = None) -> str:
1164
+ if path is None:
1165
+ path = os.getcwd()
1166
+ if name in os.listdir(path):
1167
+ path_name = os.path.join(path, name)
1168
+ print(f"{name} found: {path_name}")
1169
+ return path_name
1170
+ parent_directory = os.path.dirname(path)
1171
+ if parent_directory == path:
1172
+ return None
1173
+ return find_path(name, parent_directory)
1174
+
1175
+ def add_comfyui_directory_to_sys_path() -> None:
1176
+ comfyui_path = find_path("ComfyUI")
1177
+ if comfyui_path is not None and os.path.isdir(comfyui_path):
1178
+ sys.path.append(comfyui_path)
1179
+ print(f"'{comfyui_path}' added to sys.path")
1180
+
1181
+ def add_extra_model_paths() -> None:
1182
+ try:
1183
+ from main import load_extra_path_config
1184
+ except ImportError:
1185
+ from utils.extra_config import load_extra_path_config
1186
+ extra_model_paths = find_path("extra_model_paths.yaml")
1187
+ if extra_model_paths is not None:
1188
+ load_extra_path_config(extra_model_paths)
1189
+ else:
1190
+ print("Could not find the extra_model_paths config file.")
1191
+
1192
+ # Initialize paths
1193
+ add_comfyui_directory_to_sys_path()
1194
+ add_extra_model_paths()
1195
+
1196
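+ # Spin up a headless ComfyUI PromptServer on a private event loop so that custom nodes
+ # register themselves into NODE_CLASS_MAPPINGS.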
+ def import_custom_nodes() -> None:
1197
+ import asyncio
1198
+ import execution
1199
+ from nodes import init_extra_nodes
1200
+ import server
1201
+ loop = asyncio.new_event_loop()
1202
+ asyncio.set_event_loop(loop)
1203
+ server_instance = server.PromptServer(loop)
1204
+ execution.PromptQueue(server_instance)
1205
+ init_extra_nodes()
1206
+
1207
+ # Import all necessary nodes
1208
+ from nodes import (
1209
+ StyleModelLoader,
1210
+ VAEEncode,
1211
+ NODE_CLASS_MAPPINGS,
1212
+ LoadImage,
1213
+ CLIPVisionLoader,
1214
+ SaveImage,
1215
+ VAELoader,
1216
+ CLIPVisionEncode,
1217
+ DualCLIPLoader,
1218
+ EmptyLatentImage,
1219
+ VAEDecode,
1220
+ UNETLoader,
1221
+ CLIPTextEncode,
1222
+ )
1223
+
1224
+ # Initialize all constant nodes and models in global context
1225
+ import_custom_nodes()
1226
+
1227
+ # Global variables for preloaded models and constants
1228
+ #with torch.inference_mode():
1229
+ # Initialize constants
1230
+ intconstant = NODE_CLASS_MAPPINGS["INTConstant"]()
1231
+ CONST_1024 = intconstant.get_value(value=1024)
1232
+
1233
+ # Load CLIP
1234
+ dualcliploader = DualCLIPLoader()
1235
+ CLIP_MODEL = dualcliploader.load_clip(
1236
+ clip_name1="t5/t5xxl_fp16.safetensors",
1237
+ clip_name2="clip_l.safetensors",
1238
+ type="flux",
1239
+ )
1240
+
1241
+ # Load VAE
1242
+ vaeloader = VAELoader()
1243
+ VAE_MODEL = vaeloader.load_vae(vae_name="FLUX1/ae.safetensors")
1244
+
1245
+ # Load UNET
1246
+ unetloader = UNETLoader()
1247
+ UNET_MODEL = unetloader.load_unet(
1248
+ unet_name="flux1-depth-dev.safetensors", weight_dtype="default"
1249
+ )
1250
+
1251
+ # Load CLIP Vision
1252
+ clipvisionloader = CLIPVisionLoader()
1253
+ CLIP_VISION_MODEL = clipvisionloader.load_clip(
1254
+ clip_name="sigclip_vision_patch14_384.safetensors"
1255
+ )
1256
+
1257
+ # Load Style Model
1258
+ stylemodelloader = StyleModelLoader()
1259
+ STYLE_MODEL = stylemodelloader.load_style_model(
1260
+ style_model_name="flux1-redux-dev.safetensors"
1261
+ )
1262
+
1263
+ # Initialize samplers
1264
+ ksamplerselect = NODE_CLASS_MAPPINGS["KSamplerSelect"]()
1265
+ SAMPLER = ksamplerselect.get_sampler(sampler_name="euler")
1266
+
1267
+ # Initialize depth model
1268
+ cr_clip_input_switch = NODE_CLASS_MAPPINGS["CR Clip Input Switch"]()
1269
+ downloadandloaddepthanythingv2model = NODE_CLASS_MAPPINGS["DownloadAndLoadDepthAnythingV2Model"]()
1270
+ DEPTH_MODEL = downloadandloaddepthanythingv2model.loadmodel(
1271
+ model="depth_anything_v2_vitl_fp32.safetensors"
1272
+ )
1273
+ cliptextencode = CLIPTextEncode()
1274
+ loadimage = LoadImage()
1275
+ vaeencode = VAEEncode()
1276
+ fluxguidance = NODE_CLASS_MAPPINGS["FluxGuidance"]()
1277
+ instructpixtopixconditioning = NODE_CLASS_MAPPINGS["InstructPixToPixConditioning"]()
1278
+ clipvisionencode = CLIPVisionEncode()
1279
+ stylemodelapplyadvanced = NODE_CLASS_MAPPINGS["StyleModelApplyAdvanced"]()
1280
+ emptylatentimage = EmptyLatentImage()
1281
+ basicguider = NODE_CLASS_MAPPINGS["BasicGuider"]()
1282
+ basicscheduler = NODE_CLASS_MAPPINGS["BasicScheduler"]()
1283
+ randomnoise = NODE_CLASS_MAPPINGS["RandomNoise"]()
1284
+ samplercustomadvanced = NODE_CLASS_MAPPINGS["SamplerCustomAdvanced"]()
1285
+ vaedecode = VAEDecode()
1286
+ cr_text = NODE_CLASS_MAPPINGS["CR Text"]()
1287
+ saveimage = SaveImage()
1288
+ getimagesizeandcount = NODE_CLASS_MAPPINGS["GetImageSizeAndCount"]()
1289
+ depthanything_v2 = NODE_CLASS_MAPPINGS["DepthAnything_V2"]()
1290
+ imageresize = NODE_CLASS_MAPPINGS["ImageResize+"]()
1291
+
1292
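+ # FLUX Depth + Redux workflow: condition the prompt on a Depth-Anything-V2 map of the
+ # structure image, inject the style image through the Redux style model, then sample
+ # with the preloaded ComfyUI graph and save the result.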
+ @spaces.GPU
1293
+ def generate_image(prompt, structure_image, style_image, depth_strength=15, style_strength=0.5, progress=gr.Progress(track_tqdm=True)) -> str:
1294
+ """Main generation function that processes inputs and returns the path to the generated image."""
1295
+ with torch.inference_mode():
1296
+ # Set up CLIP
1297
+ clip_switch = cr_clip_input_switch.switch(
1298
+ Input=1,
1299
+ clip1=get_value_at_index(CLIP_MODEL, 0),
1300
+ clip2=get_value_at_index(CLIP_MODEL, 0),
1301
+ )
1302
+
1303
+ # Encode text
1304
+ text_encoded = cliptextencode.encode(
1305
+ text=prompt,
1306
+ clip=get_value_at_index(clip_switch, 0),
1307
+ )
1308
+ empty_text = cliptextencode.encode(
1309
+ text="",
1310
+ clip=get_value_at_index(clip_switch, 0),
1311
+ )
1312
+
1313
+ # Process structure image
1314
+ structure_img = loadimage.load_image(image=structure_image)
1315
+
1316
+ # Resize image
1317
+ resized_img = imageresize.execute(
1318
+ width=get_value_at_index(CONST_1024, 0),
1319
+ height=get_value_at_index(CONST_1024, 0),
1320
+ interpolation="bicubic",
1321
+ method="keep proportion",
1322
+ condition="always",
1323
+ multiple_of=16,
1324
+ image=get_value_at_index(structure_img, 0),
1325
+ )
1326
+
1327
+ # Get image size
1328
+ size_info = getimagesizeandcount.getsize(
1329
+ image=get_value_at_index(resized_img, 0)
1330
+ )
1331
+
1332
+ # Encode VAE
1333
+ vae_encoded = vaeencode.encode(
1334
+ pixels=get_value_at_index(size_info, 0),
1335
+ vae=get_value_at_index(VAE_MODEL, 0),
1336
+ )
1337
+
1338
+ # Process depth
1339
+ depth_processed = depthanything_v2.process(
1340
+ da_model=get_value_at_index(DEPTH_MODEL, 0),
1341
+ images=get_value_at_index(size_info, 0),
1342
+ )
1343
+
1344
+ # Apply Flux guidance
1345
+ flux_guided = fluxguidance.append(
1346
+ guidance=depth_strength,
1347
+ conditioning=get_value_at_index(text_encoded, 0),
1348
+ )
1349
+
1350
+ # Process style image
1351
+ style_img = loadimage.load_image(image=style_image)
1352
+
1353
+ # Encode style with CLIP Vision
1354
+ style_encoded = clipvisionencode.encode(
1355
+ crop="center",
1356
+ clip_vision=get_value_at_index(CLIP_VISION_MODEL, 0),
1357
+ image=get_value_at_index(style_img, 0),
1358
+ )
1359
+
1360
+ # Set up conditioning
1361
+ conditioning = instructpixtopixconditioning.encode(
1362
+ positive=get_value_at_index(flux_guided, 0),
1363
+ negative=get_value_at_index(empty_text, 0),
1364
+ vae=get_value_at_index(VAE_MODEL, 0),
1365
+ pixels=get_value_at_index(depth_processed, 0),
1366
+ )
1367
+
1368
+ # Apply style
1369
+ style_applied = stylemodelapplyadvanced.apply_stylemodel(
1370
+ strength=style_strength,
1371
+ conditioning=get_value_at_index(conditioning, 0),
1372
+ style_model=get_value_at_index(STYLE_MODEL, 0),
1373
+ clip_vision_output=get_value_at_index(style_encoded, 0),
1374
+ )
1375
+
1376
+ # Set up empty latent
1377
+ empty_latent = emptylatentimage.generate(
1378
+ width=get_value_at_index(resized_img, 1),
1379
+ height=get_value_at_index(resized_img, 2),
1380
+ batch_size=1,
1381
+ )
1382
+
1383
+ # Set up guidance
1384
+ guided = basicguider.get_guider(
1385
+ model=get_value_at_index(UNET_MODEL, 0),
1386
+ conditioning=get_value_at_index(style_applied, 0),
1387
+ )
1388
+
1389
+ # Set up scheduler
1390
+ schedule = basicscheduler.get_sigmas(
1391
+ scheduler="simple",
1392
+ steps=28,
1393
+ denoise=1,
1394
+ model=get_value_at_index(UNET_MODEL, 0),
1395
+ )
1396
+
1397
+ # Generate random noise
1398
+ noise = randomnoise.get_noise(noise_seed=random.randint(1, 2**64))
1399
+
1400
+ # Sample
1401
+ sampled = samplercustomadvanced.sample(
1402
+ noise=get_value_at_index(noise, 0),
1403
+ guider=get_value_at_index(guided, 0),
1404
+ sampler=get_value_at_index(SAMPLER, 0),
1405
+ sigmas=get_value_at_index(schedule, 0),
1406
+ latent_image=get_value_at_index(empty_latent, 0),
1407
+ )
1408
+
1409
+ # Decode VAE
1410
+ decoded = vaedecode.decode(
1411
+ samples=get_value_at_index(sampled, 0),
1412
+ vae=get_value_at_index(VAE_MODEL, 0),
1413
+ )
1414
+
1415
+ # Save image
1416
+ prefix = cr_text.text_multiline(text="Flux_BFL_Depth_Redux")
1417
+
1418
+ saved = saveimage.save_images(
1419
+ filename_prefix=get_value_at_index(prefix, 0),
1420
+ images=get_value_at_index(decoded, 0),
1421
+ )
1422
+ saved_path = f"output/{saved['ui']['images'][0]['filename']}"
1423
+ return saved_path
1424
+
1425
+ # Create Gradio interface
1426
+
1427
+ examples = [
1428
+ ["", "mona.png", "receita-tacos.webp", 15, 0.6],
1429
+ ["a woman looking at a house catching fire on the background", "disaster_girl.png", "abaporu.jpg", 15, 0.15],
1430
+ ["istanbul aerial, dramatic photography", "natasha.png", "istambul.jpg", 15, 0.5],
1431
+ ]
1432
+
1433
+ output_image = gr.Image(label="Generated Image")
1434
+
1435
+ with gr.Blocks() as app:
1436
+ with gr.Tab("Relighting"):
1437
+ with gr.Row():
1438
+ gr.Markdown("## Product Placement from Text")
1439
+ with gr.Row():
1440
+ with gr.Column():
1441
+ with gr.Row():
1442
+ input_fg = gr.Image(type="pil", label="Image", height=480)
1443
+ with gr.Row():
1444
+ with gr.Group():
1445
+ find_objects_button = gr.Button(value="(Option 1) Segment Object from text")
1446
+ text_prompt = gr.Textbox(
1447
+ label="Text Prompt",
1448
+ placeholder="Enter object classes separated by periods (e.g. 'car . person .'), leave empty to get all objects",
1449
+ value=""
1450
+ )
1451
+ extract_button = gr.Button(value="Remove Background")
1452
+ with gr.Row():
1453
+ extracted_objects = gr.Image(type="numpy", label="Extracted Foreground", height=480)
1454
+ extracted_fg = gr.Image(type="pil", label="Extracted Foreground", height=480)
1455
+ angles_fg = gr.Image(type="pil", label="Converted Foreground", height=480, visible=False)
1456
+
1457
+
1458
+
1459
+ # output_bg = gr.Image(type="numpy", label="Preprocessed Foreground", height=480)
1460
+ with gr.Group():
1461
+ run_button = gr.Button("Generate alternative angles")
1462
+ orientation_result = gr.Gallery(
1463
+ label="Result",
1464
+ show_label=False,
1465
+ columns=[3],
1466
+ rows=[2],
1467
+ object_fit="fill",
1468
+ height="auto",
1469
+ allow_preview=False,
1470
+ )
1471
+
1472
+ if orientation_result:
1473
+ orientation_result.select(use_orientation, inputs=None, outputs=extracted_fg)
1474
+
1475
+ dummy_image_for_outputs = gr.Image(visible=False, label='Result')
1476
+
1477
+
1478
+ with gr.Column():
1479
+ result_gallery = gr.Gallery(height=832, object_fit='contain', label='Outputs')
1480
+
1481
+ with gr.Row():
1482
+ with gr.Group():
1483
+ prompt = gr.Textbox(label="Prompt")
1484
+ bg_source = gr.Radio(choices=[e.value for e in list(BGSource)[2:]],
1485
+ value=BGSource.LEFT.value,
1486
+ label="Lighting Preference (Initial Latent)", type='value')
1487
+
1488
+ example_quick_subjects = gr.Dataset(samples=quick_subjects, label='Subject Quick List', samples_per_page=1000, components=[prompt])
1489
+ example_quick_prompts = gr.Dataset(samples=quick_prompts, label='Lighting Quick List', samples_per_page=1000, components=[prompt])
1490
+ with gr.Row():
1491
+ relight_button = gr.Button(value="Relight")
1492
+
1493
+ with gr.Group(visible=False):
1494
+ with gr.Row():
1495
+ num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=1, step=1)
1496
+ seed = gr.Number(label="Seed", value=12345, precision=0)
1497
+
1498
+ with gr.Row():
1499
+ image_width = gr.Slider(label="Image Width", minimum=256, maximum=1024, value=512, step=64)
1500
+ image_height = gr.Slider(label="Image Height", minimum=256, maximum=1024, value=640, step=64)
1501
+
1502
+ with gr.Accordion("Advanced options", open=False):
1503
+ steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=15, step=1)
1504
+ cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=2, step=0.01, visible=False)
1505
+ lowres_denoise = gr.Slider(label="Lowres Denoise (for initial latent)", minimum=0.1, maximum=1.0, value=0.9, step=0.01)
1506
+ highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01)
1507
+ highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=1.0, value=0.5, step=0.01)
1508
+ a_prompt = gr.Textbox(label="Added Prompt", value='best quality', visible=False)
1509
+ n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality', visible=False)
1510
+ x_slider = gr.Slider(
1511
+ minimum=0,
1512
+ maximum=1000,
1513
+ label="X Position",
1514
+ value=500,
1515
+ visible=False
1516
+ )
1517
+ y_slider = gr.Slider(
1518
+ minimum=0,
1519
+ maximum=1000,
1520
+ label="Y Position",
1521
+ value=500,
1522
+ visible=False
1523
+ )
1524
+
1525
+ # with gr.Row():
1526
+
1527
+ # gr.Examples(
1528
+ # fn=lambda *args: ([args[-1]], None),
1529
+ # examples=db_examples.foreground_conditioned_examples,
1530
+ # inputs=[
1531
+ # input_fg, prompt, bg_source, image_width, image_height, seed, dummy_image_for_outputs
1532
+ # ],
1533
+ # outputs=[result_gallery, output_bg],
1534
+ # run_on_click=True, examples_per_page=1024
1535
+ # )
1536
+ ips = [extracted_fg, prompt, image_width, image_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, lowres_denoise, bg_source]
1537
+ relight_button.click(fn=process_relight, inputs=ips, outputs=[result_gallery])
1538
+ example_quick_prompts.click(lambda x, y: ', '.join(y.split(', ')[:2] + [x[0]]), inputs=[example_quick_prompts, prompt], outputs=prompt, show_progress=False, queue=False)
1539
+ example_quick_subjects.click(lambda x: x[0], inputs=example_quick_subjects, outputs=prompt, show_progress=False, queue=False)
1540
+
1541
+
1542
+ def convert_to_pil(image):
1543
+ try:
1544
+ #logging.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
1545
+ if isinstance(image, Image.Image):
+ image = np.array(image)
+ image = image.astype(np.uint8)
1546
+ logging.info(f"Converted image shape: {image.shape}, dtype: {image.dtype}")
1547
+ return image
1548
+ except Exception as e:
1549
+ logging.error(f"Error converting image: {e}")
1550
+ return image
1551
+
1552
+ run_button.click(
1553
+ fn=convert_to_pil,
1554
+ inputs=extracted_fg, # This is already RGBA with removed background
1555
+ outputs=angles_fg
1556
+ ).then(
1557
+ fn=infer,
1558
+ inputs=[
1559
+ text_prompt,
1560
+ extracted_fg, # Already processed RGBA image
1561
+ ],
1562
+ outputs=[orientation_result],
1563
+ )
1564
+
1565
+ find_objects_button.click(
1566
+ fn=process_image,
1567
+ inputs=[input_fg, text_prompt],
1568
+ outputs=[extracted_objects, extracted_fg, x_slider, y_slider] # process_image returns four values
1569
+ )
1570
+
1571
+ extract_button.click(
1572
+ fn=extract_foreground,
1573
+ inputs=[input_fg],
1574
+ outputs=[extracted_fg, x_slider, y_slider]
1575
+ )
1576
+ with gr.Tab("FLUX Style Shaping"): # the rows below belong inside this tab
1577
+ gr.Markdown("Flux[dev] Redux + Flux[dev] Depth ComfyUI workflow by [Nathan Shipley](https://x.com/CitizenPlain) running directly on Gradio. [workflow](https://gist.github.com/nathanshipley/7a9ac1901adde76feebe58d558026f68) - [how to convert your any comfy workflow to gradio (soon)](#)")
1578
+ with gr.Row():
1579
+ with gr.Column():
1580
+ prompt_input = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
1581
+ with gr.Row():
1582
+ with gr.Group():
1583
+ structure_image = gr.Image(label="Structure Image", type="filepath")
1584
+ depth_strength = gr.Slider(minimum=0, maximum=50, value=15, label="Depth Strength")
1585
+ with gr.Group():
1586
+ style_image = gr.Image(label="Style Image", type="filepath")
1587
+ style_strength = gr.Slider(minimum=0, maximum=1, value=0.5, label="Style Strength")
1588
+ generate_btn = gr.Button("Generate")
1589
+
1590
+ gr.Examples(
1591
+ examples=examples,
1592
+ inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
1593
+ outputs=[output_image],
1594
+ fn=generate_image,
1595
+ cache_examples=True,
1596
+ cache_mode="lazy"
1597
+ )
1598
+
1599
+ with gr.Column():
1600
+ output_image.render()
1601
+ generate_btn.click(
1602
+ fn=generate_image,
1603
+ inputs=[prompt_input, structure_image, style_image, depth_strength, style_strength],
1604
+ outputs=[output_image]
1605
+ )
1606
+
1607
+ if __name__ == "__main__":
1608
+ app.launch(share=True)