lehduong commited on
Commit
2db0e1a
1 Parent(s): de480b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +699 -124
app.py CHANGED
@@ -1,154 +1,729 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
- import random
4
 
5
- # import spaces #[uncomment to use ZeroGPU]
6
- from diffusers import DiffusionPipeline
7
- import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
- device = "cuda" if torch.cuda.is_available() else "cpu"
10
- model_repo_id = "stabilityai/sdxl-turbo" # Replace to the model you would like to use
 
 
 
 
 
 
11
 
12
- if torch.cuda.is_available():
13
- torch_dtype = torch.float16
14
- else:
15
- torch_dtype = torch.float32
16
 
17
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
18
- pipe = pipe.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- MAX_SEED = np.iinfo(np.int32).max
21
- MAX_IMAGE_SIZE = 1024
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
 
24
- # @spaces.GPU #[uncomment to use ZeroGPU]
25
- def infer(
26
- prompt,
27
- negative_prompt,
28
- seed,
29
- randomize_seed,
30
- width,
31
- height,
32
- guidance_scale,
33
- num_inference_steps,
34
- progress=gr.Progress(track_tqdm=True),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  ):
36
- if randomize_seed:
37
- seed = random.randint(0, MAX_SEED)
38
-
39
- generator = torch.Generator().manual_seed(seed)
40
-
41
- image = pipe(
42
- prompt=prompt,
43
- negative_prompt=negative_prompt,
44
- guidance_scale=guidance_scale,
45
- num_inference_steps=num_inference_steps,
46
- width=width,
47
- height=height,
48
- generator=generator,
49
- ).images[0]
50
-
51
- return image, seed
52
-
53
-
54
- examples = [
55
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
56
- "An astronaut riding a green horse",
57
- "A delicious ceviche cheesecake slice",
58
- ]
59
-
60
- css = """
61
- #col-container {
62
- margin: 0 auto;
63
- max-width: 640px;
64
- }
65
- """
66
 
67
- with gr.Blocks(css=css) as demo:
68
- with gr.Column(elem_id="col-container"):
69
- gr.Markdown(" # Text-to-Image Gradio Template")
 
70
 
71
- with gr.Row():
72
- prompt = gr.Text(
73
- label="Prompt",
74
- show_label=False,
75
- max_lines=1,
76
- placeholder="Enter your prompt",
77
- container=False,
78
- )
79
 
80
- run_button = gr.Button("Run", scale=0, variant="primary")
 
 
 
 
 
 
 
 
 
81
 
82
- result = gr.Image(label="Result", show_label=False)
 
 
 
83
 
84
- with gr.Accordion("Advanced Settings", open=False):
85
- negative_prompt = gr.Text(
86
- label="Negative prompt",
87
- max_lines=1,
88
- placeholder="Enter a negative prompt",
89
- visible=False,
90
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
- seed = gr.Slider(
93
- label="Seed",
94
- minimum=0,
95
- maximum=MAX_SEED,
96
- step=1,
97
- value=0,
 
 
 
 
 
 
 
 
 
 
 
 
98
  )
 
 
 
99
 
100
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
- with gr.Row():
103
- width = gr.Slider(
104
- label="Width",
105
- minimum=256,
106
- maximum=MAX_IMAGE_SIZE,
107
- step=32,
108
- value=1024, # Replace with defaults that work for your model
109
- )
110
 
111
- height = gr.Slider(
112
- label="Height",
113
- minimum=256,
114
- maximum=MAX_IMAGE_SIZE,
115
- step=32,
116
- value=1024, # Replace with defaults that work for your model
117
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  with gr.Row():
120
- guidance_scale = gr.Slider(
121
- label="Guidance scale",
122
- minimum=0.0,
123
- maximum=10.0,
124
- step=0.1,
125
- value=0.0, # Replace with defaults that work for your model
 
126
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
 
128
- num_inference_steps = gr.Slider(
129
- label="Number of inference steps",
130
- minimum=1,
131
- maximum=50,
132
- step=1,
133
- value=2, # Replace with defaults that work for your model
134
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
- gr.Examples(examples=examples, inputs=[prompt])
137
- gr.on(
138
- triggers=[run_button.click, prompt.submit],
139
- fn=infer,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
  inputs=[
141
- prompt,
142
- negative_prompt,
143
- seed,
144
- randomize_seed,
145
- width,
146
- height,
147
- guidance_scale,
148
- num_inference_steps,
149
  ],
150
- outputs=[result, seed],
 
 
 
 
 
 
 
 
 
 
 
 
151
  )
152
 
153
  if __name__ == "__main__":
154
- demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import base64
4
+ import io
5
+ from PIL import Image
6
+ from transformers import (
7
+ LlavaNextProcessor, LlavaNextForConditionalGeneration,
8
+ T5EncoderModel, T5Tokenizer
9
+ )
10
+ from transformers import (
11
+ AutoProcessor, AutoModelForCausalLM, GenerationConfig,
12
+ T5EncoderModel, T5Tokenizer
13
+ )
14
+ from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FluxPipeline
15
+ from tordi.diffusion.pipelines.onediffusion import OneDiffusionPipeline
16
+ from tordi.models.denoiser.nextdit import NextDiT
17
+ from tordi.dataset.utils import get_closest_ratio, ASPECT_RATIO_512
18
+ from typing import List, Optional
19
+
20
+ # Import additional libraries
21
+ import matplotlib
22
  import numpy as np
23
+ import cv2
24
 
25
+ # Task-specific tokens
26
+ TASK2SPECIAL_TOKENS = {
27
+ "text2image": "[[text2image]]",
28
+ "deblurring": "[[deblurring]]",
29
+ "inpainting": "[[image_inpainting]]",
30
+ "canny": "[[canny2image]]",
31
+ "super_resolution": "[[super_resolution]]",
32
+ "depth2image": "[[depth2image]]",
33
+ "hed2image": "[[hed2img]]",
34
+ "pose2image": "[[pose2image]]",
35
+ "semanticmap2image": "[[semanticmap2image]]",
36
+ "boundingbox2image": "[[boundingbox2image]]",
37
+ "image_editing": "[[image_editing]]",
38
+ "faceid": "[[faceid]]",
39
+ "multiview": "[[multiview]]",
40
+ "subject_driven": "[[subject_driven]]"
41
+ }
42
+ NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
43
+
44
+
45
+ class LlavaCaptionProcessor:
46
+ def __init__(self):
47
+ model_name = "llava-hf/llama3-llava-next-8b-hf"
48
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
49
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
50
+ self.processor = LlavaNextProcessor.from_pretrained(model_name)
51
+ self.model = LlavaNextForConditionalGeneration.from_pretrained(
52
+ model_name, torch_dtype=dtype, low_cpu_mem_usage=True
53
+ ).to(device)
54
+ self.SPECIAL_TOKENS = "assistant\n\n\n"
55
+
56
+ def generate_response(self, image: Image.Image, msg: str) -> str:
57
+ conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": msg}]}]
58
+ with torch.no_grad():
59
+ prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
60
+ inputs = self.processor(prompt, image, return_tensors="pt").to(self.model.device)
61
+ output = self.model.generate(**inputs, max_new_tokens=250)
62
+ response = self.processor.decode(output[0], skip_special_tokens=True)
63
+ return response.split(msg)[-1].strip()[len(self.SPECIAL_TOKENS):]
64
 
65
+ def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
66
+ if msg is None:
67
+ msg = f"Describe the contents of the photo in 150 words or fewer."
68
+ try:
69
+ return [self.generate_response(img, msg) for img in images]
70
+ except Exception as e:
71
+ print(f"Error in process: {str(e)}")
72
+ raise
73
 
 
 
 
 
74
 
75
+ class MolmoCaptionProcessor:
76
+ def __init__(self):
77
+ pretrained_model_name = 'allenai/Molmo-7B-O-0924'
78
+ self.processor = AutoProcessor.from_pretrained(
79
+ pretrained_model_name,
80
+ trust_remote_code=True,
81
+ torch_dtype='auto',
82
+ device_map='auto'
83
+ )
84
+ self.model = AutoModelForCausalLM.from_pretrained(
85
+ pretrained_model_name,
86
+ trust_remote_code=True,
87
+ torch_dtype='auto',
88
+ device_map='auto'
89
+ )
90
 
91
+ def generate_response(self, image: Image.Image, msg: str) -> str:
92
+ inputs = self.processor.process(
93
+ images=[image],
94
+ text=msg
95
+ )
96
+ # Move inputs to the correct device and make a batch of size 1
97
+ inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
98
+
99
+ # Generate output
100
+ output = self.model.generate_from_batch(
101
+ inputs,
102
+ GenerationConfig(max_new_tokens=250, stop_strings="<|endoftext|>"),
103
+ tokenizer=self.processor.tokenizer
104
+ )
105
+
106
+ # Only get generated tokens and decode them to text
107
+ generated_tokens = output[0, inputs['input_ids'].size(1):]
108
+ return self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
109
 
110
 
111
+ def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
112
+ if msg is None:
113
+ msg = f"Describe the contents of the photo in 150 words or fewer."
114
+ try:
115
+ return [self.generate_response(img, msg) for img in images]
116
+ except Exception as e:
117
+ print(f"Error in process: {str(e)}")
118
+ raise
119
+
120
+ def initialize_models():
121
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
122
+ model = NextDiT.from_pretrained(
123
+ # "/data/input/duongl/finetuning_distributed_multiview_16x8_scalerays_dl3dv_dynamic_shift_softcap_editing/checkpoint-98000",
124
+ # "/data/input/duongl/data/input/duongl/finetuning_distributed_multiview_16x8_scalerays_dl3dv_dynamic_shift_softcap_trainingWithFluxScheduler/checkpoint-10000/", # "lehduong/OneDiffusion",
125
+ "lehduong/OneDiffusion",
126
+ subfolder="transformer",
127
+ torch_dtype=torch.float32,
128
+ ).to(device)
129
+ vae = AutoencoderKL.from_pretrained("lehduong/OneDiffusion", subfolder="vae").to(device)
130
+ text_encoder = T5EncoderModel.from_pretrained("lehduong/OneDiffusion", subfolder="text_encoder", torch_dtype=torch.float16).to(device)
131
+ tokenizer = T5Tokenizer.from_pretrained("lehduong/OneDiffusion", subfolder="tokenizer")
132
+ scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
133
+ # "stabilityai/stable-diffusion-3-medium-diffusers",
134
+ # "black-forest-labs/FLUX.1-dev",
135
+ "lehduong/OneDiffusion",
136
+ subfolder="scheduler"
137
+ )
138
+ # scheduler = FlowMatchEulerDiscreteScheduler(
139
+ # base_image_seq_len=256,
140
+ # base_shift=0.5,
141
+ # max_image_seq_len=4096,
142
+ # max_shift=1.16,
143
+ # num_train_timesteps=1000,
144
+ # shift=3.0,
145
+ # use_dynamic_shifting=True
146
+ # )
147
+ pipeline = OneDiffusionPipeline(
148
+ vae=vae, text_encoder=text_encoder, transformer=model, tokenizer=tokenizer, scheduler=scheduler
149
+ ).to(torch.bfloat16)
150
+ molmo_caption_processor = MolmoCaptionProcessor() # LlavaCaptionProcessor()
151
+ return pipeline, molmo_caption_processor
152
+
153
+ def colorize_depth_maps(
154
+ depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
155
  ):
156
+ """
157
+ Colorize depth maps with reversed colors.
158
+ """
159
+ assert len(depth_map.shape) >= 2, "Invalid dimension"
160
+
161
+ if isinstance(depth_map, torch.Tensor):
162
+ depth = depth_map.detach().squeeze().numpy()
163
+ elif isinstance(depth_map, np.ndarray):
164
+ depth = depth_map.copy().squeeze()
165
+ # reshape to [ (B,) H, W ]
166
+ if depth.ndim < 3:
167
+ depth = depth[np.newaxis, :, :]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
+ # Normalize depth values to [0, 1]
170
+ depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
171
+ # Invert the depth values to reverse the colors
172
+ depth = 1 - depth
173
 
174
+ # Use the colormap
175
+ cm = matplotlib.colormaps[cmap]
176
+ img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # values from 0 to 1
177
+ img_colored_np = np.rollaxis(img_colored_np, 3, 1)
 
 
 
 
178
 
179
+ if valid_mask is not None:
180
+ if isinstance(depth_map, torch.Tensor):
181
+ valid_mask = valid_mask.detach().numpy()
182
+ valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
183
+ if valid_mask.ndim < 3:
184
+ valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
185
+ else:
186
+ valid_mask = valid_mask[:, np.newaxis, :, :]
187
+ valid_mask = np.repeat(valid_mask, 3, axis=1)
188
+ img_colored_np[~valid_mask] = 0
189
 
190
+ if isinstance(depth_map, torch.Tensor):
191
+ img_colored = torch.from_numpy(img_colored_np).float()
192
+ elif isinstance(depth_map, np.ndarray):
193
+ img_colored = img_colored_np
194
 
195
+ return img_colored
196
+
197
+
198
+ def format_prompt(task_type: str, captions: List[str]) -> str:
199
+ if not captions:
200
+ return ""
201
+ if task_type == "faceid":
202
+ img_prompts = [f"[[img{i}]] {caption}" for i, caption in enumerate(captions, start=1)]
203
+ return f"[[faceid]] [[img0]] insert/your/caption/here {' '.join(img_prompts)}"
204
+ elif task_type == "image_editing":
205
+ target_caption = captions[0] if len(captions) >= 2 else "Insert target caption here"
206
+ source_caption = captions[1] if len(captions) >= 2 else captions[0]
207
+ return f"[[image_editing]] [[target_caption]] {target_caption} [[source_caption]] {source_caption}"
208
+ elif task_type == "semanticmap2image":
209
+ return f"[[semanticmap2image]] <#00ffff Cyan mask: insert/concept/to/segment/here> {captions[0]}"
210
+ elif task_type == "boundingbox2image":
211
+ return f"[[boundingbox2image]] <#00ffff Cyan boundingbox: insert/concept/to/segment/here> {captions[0]}"
212
+ elif task_type == "multiview":
213
+ # img_prompts = [f"[[img{i}]] {caption}" for i, caption in enumerate(captions)]
214
+ img_prompts = captions[0]
215
+ return f"[[multiview]] {img_prompts}"
216
+ elif task_type == "subject_driven":
217
+ return f"[[subject_driven]] <item: insert/item/here> [[img0]] insert/your/target/caption/here [[img1]] {captions[0]}"
218
+ else:
219
+ return f"{TASK2SPECIAL_TOKENS[task_type]} {captions[0]}"
220
+
221
+ def update_prompt(images: List[Image.Image], task_type: str, custom_msg: str = None):
222
+ if not images:
223
+ return format_prompt(task_type, []), "Please upload at least one image!"
224
+ try:
225
+ captions = molmo_processor.process(images, custom_msg)
226
+ if not captions:
227
+ return "", "No valid images found!"
228
+ prompt = format_prompt(task_type, captions)
229
+ return prompt, f"Generated {len(captions)} captions successfully!"
230
+ except Exception as e:
231
+ return "", f"Error generating captions: {str(e)}"
232
+
233
+ def generate_image(images: List[Image.Image], prompt: str, negative_prompt: str, num_inference_steps: int, guidance_scale: float, pag_guidance_scale: float,
234
+ denoise_mask: List[str], task_type: str, azimuth: str, elevation: str, distance: str, focal_length: float,
235
+ height: int = 1024, width: int = 1024, scale_factor: float = 1.0, scale_watershed: float = 1.0,
236
+ noise_scale: float = None, progress=gr.Progress()):
237
+ try:
238
+ img2img_kwargs = {
239
+ 'prompt': prompt,
240
+ 'negative_prompt': negative_prompt,
241
+ 'num_inference_steps': num_inference_steps,
242
+ 'guidance_scale': guidance_scale,
243
+ 'height': height,
244
+ 'width': width,
245
+ 'forward_kwargs': {
246
+ 'scale_factor': scale_factor,
247
+ 'scale_watershed': scale_watershed
248
+ },
249
+ 'noise_scale': noise_scale # Added noise_scale here
250
+ }
251
+
252
+ if task_type == 'multiview':
253
+ # Parse azimuth, elevation, and distance into lists, allowing 'None' values
254
+ azimuths = [float(a.strip()) if a.strip().lower() != 'none' else None for a in azimuth.split(',')] if azimuth else []
255
+ elevations = [float(e.strip()) if e.strip().lower() != 'none' else None for e in elevation.split(',')] if elevation else []
256
+ distances = [float(d.strip()) if d.strip().lower() != 'none' else None for d in distance.split(',')] if distance else []
257
+
258
+ num_views = max(len(images), len(azimuths), len(elevations), len(distances))
259
+ if num_views == 0:
260
+ return None, "At least one image or camera parameter must be provided."
261
+
262
+ total_components = []
263
+ for i in range(num_views):
264
+ total_components.append(f"image_{i}")
265
+ total_components.append(f"camera_pose_{i}")
266
+
267
+ denoise_mask_int = [1 if comp in denoise_mask else 0 for comp in total_components]
268
+
269
+ if len(denoise_mask_int) != len(total_components):
270
+ return None, f"Denoise mask length mismatch: expected {len(total_components)} components."
271
+
272
+ # Pad the input lists to num_views length
273
+ images_padded = images + [] * (num_views - len(images)) # Do not add None
274
+ azimuths_padded = azimuths + [None] * (num_views - len(azimuths))
275
+ elevations_padded = elevations + [None] * (num_views - len(elevations))
276
+ distances_padded = distances + [None] * (num_views - len(distances))
277
+
278
+ # Prepare values
279
+ img2img_kwargs.update({
280
+ 'image': images_padded,
281
+ 'multiview_azimuths': azimuths_padded,
282
+ 'multiview_elevations': elevations_padded,
283
+ 'multiview_distances': distances_padded,
284
+ 'multiview_focal_length': focal_length, # Pass focal_length here
285
+ 'is_multiview': True,
286
+ 'denoise_mask': denoise_mask_int,
287
+ # 'predict_camera_poses': True,
288
+ })
289
+ else:
290
+ total_components = ["image_0"] + [f"image_{i+1}" for i in range(len(images))]
291
+ denoise_mask_int = [1 if comp in denoise_mask else 0 for comp in total_components]
292
+ if len(denoise_mask_int) != len(total_components):
293
+ return None, f"Denoise mask length mismatch: expected {len(total_components)} components."
294
 
295
+ img2img_kwargs.update({
296
+ 'image': images,
297
+ 'denoise_mask': denoise_mask_int
298
+ })
299
+
300
+ progress(0, desc="Generating image...")
301
+ if task_type == 'text2image':
302
+ output = pipeline(
303
+ prompt=prompt,
304
+ negative_prompt=negative_prompt,
305
+ num_inference_steps=num_inference_steps,
306
+ guidance_scale=guidance_scale,
307
+ pag_guidance_scale=pag_guidance_scale,
308
+ height=height,
309
+ width=width,
310
+ scale_factor=scale_factor,
311
+ scale_watershed=scale_watershed,
312
+ noise_scale=noise_scale # Added noise_scale here
313
  )
314
+ else:
315
+ output = pipeline.img2img(**img2img_kwargs)
316
+ progress(1, desc="Done!")
317
 
318
+ # Process the output images if task is 'depth2image' and predicting depth
319
+ if task_type == 'depth2image' and denoise_mask_int[-1] == 1:
320
+ processed_images = []
321
+ for img in output.images:
322
+ depth_map = np.array(img.convert('L')) # Convert to grayscale numpy array
323
+ min_depth = depth_map.min()
324
+ max_depth = depth_map.max()
325
+ colorized = colorize_depth_maps(depth_map, min_depth, max_depth)[0]
326
+ colorized = np.transpose(colorized, (1, 2, 0))
327
+ colorized = (colorized * 255).astype(np.uint8)
328
+ img_colorized = Image.fromarray(colorized)
329
+ processed_images.append(img_colorized)
330
+ output_images = processed_images + output.images
331
+ elif task_type in ['boundingbox2image', 'semanticmap2image'] and denoise_mask_int == [0,1] and images:
332
+ # Interpolate between input and output images
333
+ processed_images = []
334
+ for input_img, output_img in zip(images, output.images):
335
+ input_img_resized = input_img.resize(output_img.size)
336
+ blended_img = Image.blend(input_img_resized, output_img, alpha=0.5)
337
+ processed_images.append(blended_img)
338
+ output_images = processed_images + output.images
339
+ else:
340
+ output_images = output.images
341
 
342
+ return output_images, "Generation completed successfully!"
 
 
 
 
 
 
 
343
 
344
+ except Exception as e:
345
+ return None, f"Error during generation: {str(e)}"
346
+
347
+ def update_denoise_checkboxes(images_state: List[Image.Image], task_type: str, azimuth: str, elevation: str, distance: str):
348
+ if task_type == 'multiview':
349
+ azimuths = [a.strip() for a in azimuth.split(',')] if azimuth else []
350
+ elevations = [e.strip() for e in elevation.split(',')] if elevation else []
351
+ distances = [d.strip() for d in distance.split(',')] if distance else []
352
+ images_len = len(images_state)
353
+
354
+ num_views = max(images_len, len(azimuths), len(elevations), len(distances))
355
+ if num_views == 0:
356
+ return gr.update(choices=[], value=[]), "Please provide at least one image or camera parameter."
357
+
358
+ # Pad lists to the same length
359
+ azimuths += ['None'] * (num_views - len(azimuths))
360
+ elevations += ['None'] * (num_views - len(elevations))
361
+ distances += ['None'] * (num_views - len(distances))
362
+ # Do not add None to images_state
363
+
364
+ labels = []
365
+ values = []
366
+ for i in range(num_views):
367
+ labels.append(f"image_{i}")
368
+ labels.append(f"camera_pose_{i}")
369
+
370
+ # Default behavior: condition on provided inputs, generate missing ones
371
+ if i >= images_len:
372
+ values.append(f"image_{i}")
373
+ if azimuths[i].lower() == 'none' or elevations[i].lower() == 'none' or distances[i].lower() == 'none':
374
+ values.append(f"camera_pose_{i}")
375
+
376
+ return gr.update(choices=labels, value=values)
377
+ else:
378
+ labels = ["image_0"] + [f"image_{i+1}" for i in range(len(images_state))]
379
+ values = ["image_0"]
380
+ return gr.update(choices=labels, value=values)
381
+
382
+ def apply_mask(images_state):
383
+ if len(images_state) < 2:
384
+ return None, "Please upload at least two images: first as the base image, second as the mask."
385
+ base_img = images_state[0]
386
+ mask_img = images_state[1]
387
+
388
+ # Convert images to arrays
389
+ base_arr = np.array(base_img)
390
+ mask_arr = np.array(mask_img)
391
+
392
+ # Convert mask to grayscale
393
+ if mask_arr.ndim == 3:
394
+ gray_mask = cv2.cvtColor(mask_arr, cv2.COLOR_RGB2GRAY)
395
+ else:
396
+ gray_mask = mask_arr
397
+
398
+ # Create a binary mask where non-black pixels are True
399
+ binary_mask = gray_mask > 10
400
+
401
+ # Define the gray color
402
+ gray_color = np.array([128, 128, 128], dtype=np.uint8)
403
+
404
+ # Apply gray color where mask is True
405
+ masked_arr = base_arr.copy()
406
+ masked_arr[binary_mask] = gray_color
407
+
408
+ masked_img = Image.fromarray(masked_arr)
409
+ return [masked_img], "Mask applied successfully!"
410
+
411
+ def process_images_for_task_type(images_state: List[Image.Image], task_type: str):
412
+ # No changes needed here since we are processing the output images
413
+ return images_state, images_state
414
+
415
+ # Initialize models
416
+ pipeline, molmo_processor = initialize_models()
417
+
418
+ with gr.Blocks(title="OneDiffusion Demo") as demo:
419
+ gr.Markdown("""
420
+ # OneDiffusion Demo
421
+
422
+ **Welcome to the OneDiffusion Demo!**
423
+
424
+ This application allows you to generate images based on your input prompts for various tasks. Here's how to use it:
425
+
426
+ 1. **Select Task Type**: Choose the type of task you want to perform from the "Task Type" dropdown menu.
427
+
428
+ 2. **Upload Images**: Drag and drop images directly onto the upload area, or click to select files from your device.
429
 
430
+ 3. **Generate Captions**: **If you upload any images**, Click the "Generate Captions with Molmo" button to generate descriptive captions for your uploaded images (depend on the task). You can enter a custom message in the "Custom Message for Molmo" textbox e.g., "caption in 50 words" instead of 100 words.
431
+
432
+ 4. **Configure Generation Settings**: Expand the "Advanced Configuration" section to adjust parameters like the number of inference steps, guidance scale, image size, and more.
433
+
434
+ 5. **Generate Images**: After setting your preferences, click the "Generate Image" button. The generated images will appear in the "Generated Images" gallery.
435
+
436
+ 6. **Manage Images**: Use the "Delete Selected Images" or "Delete All Images" buttons to remove unwanted images from the gallery.
437
+
438
+ **Notes**:
439
+ - For text-to-image:
440
+ + simply enter your prompt in this format "[[text2image]] your/prompt/here" and press the "Generate Image" button.
441
+
442
+ - For boundingbox2image/semantic2image/inpainting etc tasks:
443
+ + To perform condition-to-image such as semantic map to image, follow above steps
444
+ + For image-to-condition e.g., image to depth, change the denoise_mask checkbox before generating images. You must UNCHECK image_0 box and CHECK image_1 box.
445
+
446
+ - For FaceID tasks:
447
+ + Use 3 or 4 images if single input image does not give satisfactory results.
448
+ + All images will be resized and center cropped to the input height and width. You should choose height and width so that faces in input images won't be cropped.
449
+ + Model works best with close-up portrait (input and output) images.
450
+ + If the model does not conform your text prompt, try using shorter caption for source image(s).
451
+ + If you have non-human subjects and does not get satisfactory results, try "copying" part of caption of source images where it describes the properties of the subject e.g., a monster with red eyes, sharp teeth, etc.
452
+
453
+ - For Multiview generation:
454
+ + Only support square images (ideally in 512x512 resolution).
455
+ + Ensure the number of elevations, azimuths, and distances are equal.
456
+ + The model generally works well for 2-5 views (include both input and generated images). Since the model is trained with 3 views on 512x512 resolution, you might try scale_factor of [1.1; 1.5] and scale_watershed of [100; 400] for better extrapolation.
457
+ + For better results:
458
+ 1) try increasing num_inference_steps to 75-100.
459
+ 2) avoid aggressively changes in target camera poses, for example to generate novel views at azimuth of 180, (simultaneously) generate 4 views with azimuth of 45, 90, 135, 180.
460
+
461
+ Enjoy creating images with OneDiffusion!
462
+ """)
463
+
464
+ with gr.Row():
465
+ with gr.Column():
466
+ images_state = gr.State([])
467
+ selected_indices_state = gr.State([])
468
+
469
  with gr.Row():
470
+ gallery = gr.Gallery(
471
+ label="Input Images",
472
+ show_label=True,
473
+ columns=2,
474
+ rows=2,
475
+ height="auto",
476
+ object_fit="contain"
477
  )
478
+
479
+ # In the UI section, update the file_output component:
480
+ file_output = gr.File(
481
+ file_count="multiple",
482
+ file_types=["image"],
483
+ label="Drag and drop images here or click to upload",
484
+ height=100,
485
+ scale=2,
486
+ type="filepath" # Add this parameter
487
+ )
488
+
489
+ with gr.Row():
490
+ delete_button = gr.Button("Delete Selected Images")
491
+ delete_all_button = gr.Button("Delete All Images")
492
+
493
+ task_type = gr.Dropdown(
494
+ choices=list(TASK2SPECIAL_TOKENS.keys()),
495
+ value="text2image",
496
+ label="Task Type"
497
+ )
498
+
499
+ molmo_message = gr.Textbox(
500
+ lines=2,
501
+ value="Describe the contents of the photo in 100 words.",
502
+ label="Custom message for Molmo captioner"
503
+ )
504
+
505
+ auto_caption_btn = gr.Button("Generate Captions with Molmo")
506
 
507
+ with gr.Column():
508
+ prompt = gr.Textbox(
509
+ lines=3,
510
+ placeholder="Enter your prompt here or use auto-caption...",
511
+ label="Prompt"
512
+ )
513
+ negative_prompt = gr.Textbox(
514
+ lines=3,
515
+ value=NEGATIVE_PROMPT,
516
+ placeholder="Enter negative prompt here...",
517
+ label="Negative Prompt"
518
+ )
519
+ caption_status = gr.Textbox(label="Caption Status")
520
+
521
+ num_steps = gr.Slider(
522
+ minimum=1,
523
+ maximum=200,
524
+ value=30,
525
+ step=1,
526
+ label="Number of Inference Steps"
527
+ )
528
+ guidance_scale = gr.Slider(
529
+ minimum=0.1,
530
+ maximum=10.0,
531
+ value=4,
532
+ step=0.1,
533
+ label="Guidance Scale"
534
+ )
535
+ pag_guidance_scale = gr.Slider(
536
+ minimum=0.1,
537
+ maximum=10.0,
538
+ value=1,
539
+ step=0.1,
540
+ label="PAG guidance Scale"
541
+ )
542
+ height = gr.Number(value=1024, label="Height")
543
+ width = gr.Number(value=1024, label="Width")
544
+
545
+ with gr.Accordion("Advanced Configuration", open=False):
546
+ with gr.Row():
547
+ denoise_mask_checkbox = gr.CheckboxGroup(
548
+ label="Denoise Mask",
549
+ choices=["image_0"],
550
+ value=["image_0"]
551
+ )
552
+ azimuth = gr.Textbox(
553
+ value="0",
554
+ label="Azimuths (degrees, comma-separated, 'None' for missing)"
555
+ )
556
+ elevation = gr.Textbox(
557
+ value="0",
558
+ label="Elevations (degrees, comma-separated, 'None' for missing)"
559
+ )
560
+ distance = gr.Textbox(
561
+ value="1.5",
562
+ label="Distances (comma-separated, 'None' for missing)"
563
+ )
564
+ focal_length = gr.Number(
565
+ value=1.3887,
566
+ label="Focal Length of camera for multiview generation"
567
+ )
568
+ scale_factor = gr.Number(value=1.0, label="Scale Factor")
569
+ scale_watershed = gr.Number(value=1.0, label="Scale Watershed")
570
+ noise_scale = gr.Number(value=1.0, label="Noise Scale") # Added noise_scale input
571
+
572
+ output_images = gr.Gallery(
573
+ label="Generated Images",
574
+ show_label=True,
575
+ columns=4,
576
+ rows=2,
577
+ height="auto",
578
+ object_fit="contain"
579
+ )
580
+
581
+ with gr.Column():
582
+ generate_btn = gr.Button("Generate Image")
583
+ apply_mask_btn = gr.Button("Apply Mask")
584
+
585
+ status = gr.Textbox(label="Generation Status")
586
+
587
+ # Event Handlers
588
+ def update_gallery(files, images_state):
589
+ if not files:
590
+ return images_state, images_state
591
+
592
+ new_images = []
593
+ for file in files:
594
+ try:
595
+ # Handle both file paths and file objects
596
+ if isinstance(file, dict): # For drag and drop files
597
+ file = file['path']
598
+ elif hasattr(file, 'name'): # For uploaded files
599
+ file = file.name
600
+
601
+ img = Image.open(file).convert('RGB')
602
+ new_images.append(img)
603
+ except Exception as e:
604
+ print(f"Error loading image: {str(e)}")
605
+ continue
606
+
607
+ images_state.extend(new_images)
608
+ return images_state, images_state
609
 
610
+ def on_image_select(evt: gr.SelectData, selected_indices_state):
611
+ selected_indices = selected_indices_state or []
612
+ index = evt.index
613
+ if index in selected_indices:
614
+ selected_indices.remove(index)
615
+ else:
616
+ selected_indices.append(index)
617
+ return selected_indices
618
+
619
+ def delete_images(selected_indices, images_state):
620
+ updated_images = [img for i, img in enumerate(images_state) if i not in selected_indices]
621
+ selected_indices_state = []
622
+ return updated_images, updated_images, selected_indices_state
623
+
624
+ def delete_all_images(images_state):
625
+ updated_images = []
626
+ selected_indices_state = []
627
+ return updated_images, updated_images, selected_indices_state
628
+
629
+ def update_height_width(images_state):
630
+ if images_state:
631
+ closest_ar = get_closest_ratio(
632
+ height=images_state[0].size[1],
633
+ width=images_state[0].size[0],
634
+ ratios=ASPECT_RATIO_512
635
+ )
636
+ height_val, width_val = int(closest_ar[0][0]), int(closest_ar[0][1])
637
+ else:
638
+ height_val, width_val = 1024, 1024 # Default values
639
+ return gr.update(value=height_val), gr.update(value=width_val)
640
+
641
+ # Connect events
642
+ file_output.change(
643
+ fn=update_gallery,
644
+ inputs=[file_output, images_state],
645
+ outputs=[images_state, gallery]
646
+ ).then(
647
+ fn=update_height_width,
648
+ inputs=[images_state],
649
+ outputs=[height, width]
650
+ ).then(
651
+ fn=update_denoise_checkboxes,
652
+ inputs=[images_state, task_type, azimuth, elevation, distance],
653
+ outputs=[denoise_mask_checkbox]
654
+ )
655
+
656
+ gallery.select(
657
+ fn=on_image_select,
658
+ inputs=[selected_indices_state],
659
+ outputs=[selected_indices_state]
660
+ )
661
+
662
+ delete_button.click(
663
+ fn=delete_images,
664
+ inputs=[selected_indices_state, images_state],
665
+ outputs=[images_state, gallery, selected_indices_state]
666
+ ).then(
667
+ fn=update_denoise_checkboxes,
668
+ inputs=[images_state, task_type, azimuth, elevation, distance],
669
+ outputs=[denoise_mask_checkbox]
670
+ )
671
+
672
+ delete_all_button.click(
673
+ fn=delete_all_images,
674
+ inputs=[images_state],
675
+ outputs=[images_state, gallery, selected_indices_state]
676
+ ).then(
677
+ fn=update_denoise_checkboxes,
678
+ inputs=[images_state, task_type, azimuth, elevation, distance],
679
+ outputs=[denoise_mask_checkbox]
680
+ )
681
+
682
+ task_type.change(
683
+ fn=update_denoise_checkboxes,
684
+ inputs=[images_state, task_type, azimuth, elevation, distance],
685
+ outputs=[denoise_mask_checkbox]
686
+ )
687
+
688
+ azimuth.change(
689
+ fn=update_denoise_checkboxes,
690
+ inputs=[images_state, task_type, azimuth, elevation, distance],
691
+ outputs=[denoise_mask_checkbox]
692
+ )
693
+
694
+ elevation.change(
695
+ fn=update_denoise_checkboxes,
696
+ inputs=[images_state, task_type, azimuth, elevation, distance],
697
+ outputs=[denoise_mask_checkbox]
698
+ )
699
+
700
+ distance.change(
701
+ fn=update_denoise_checkboxes,
702
+ inputs=[images_state, task_type, azimuth, elevation, distance],
703
+ outputs=[denoise_mask_checkbox]
704
+ )
705
+
706
+ generate_btn.click(
707
+ fn=generate_image,
708
  inputs=[
709
+ images_state, prompt, negative_prompt, num_steps, guidance_scale, pag_guidance_scale,
710
+ denoise_mask_checkbox, task_type, azimuth, elevation, distance,
711
+ focal_length, height, width, scale_factor, scale_watershed, noise_scale # Added noise_scale here
 
 
 
 
 
712
  ],
713
+ outputs=[output_images, status]
714
+ )
715
+
716
+ auto_caption_btn.click(
717
+ fn=update_prompt,
718
+ inputs=[images_state, task_type, molmo_message],
719
+ outputs=[prompt, caption_status]
720
+ )
721
+
722
+ apply_mask_btn.click(
723
+ fn=apply_mask,
724
+ inputs=[images_state],
725
+ outputs=[output_images, status]
726
  )
727
 
728
  if __name__ == "__main__":
729
+ demo.launch(share=True)