Spaces:
Running
on
Zero
Running
on
Zero
Update app.py
Browse files
app.py
CHANGED
@@ -1,154 +1,729 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
import numpy as np
|
3 |
-
import
|
4 |
|
5 |
-
#
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
|
10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
11 |
|
12 |
-
if torch.cuda.is_available():
|
13 |
-
torch_dtype = torch.float16
|
14 |
-
else:
|
15 |
-
torch_dtype = torch.float32
|
16 |
|
17 |
-
|
18 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
):
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
generator=generator,
|
49 |
-
).images[0]
|
50 |
-
|
51 |
-
return image, seed
|
52 |
-
|
53 |
-
|
54 |
-
examples = [
|
55 |
-
"Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
|
56 |
-
"An astronaut riding a green horse",
|
57 |
-
"A delicious ceviche cheesecake slice",
|
58 |
-
]
|
59 |
-
|
60 |
-
css = """
|
61 |
-
#col-container {
|
62 |
-
margin: 0 auto;
|
63 |
-
max-width: 640px;
|
64 |
-
}
|
65 |
-
"""
|
66 |
|
67 |
-
|
68 |
-
|
69 |
-
|
|
|
70 |
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
max_lines=1,
|
76 |
-
placeholder="Enter your prompt",
|
77 |
-
container=False,
|
78 |
-
)
|
79 |
|
80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
81 |
|
82 |
-
|
|
|
|
|
|
|
83 |
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
91 |
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
)
|
|
|
|
|
|
|
99 |
|
100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
|
102 |
-
|
103 |
-
width = gr.Slider(
|
104 |
-
label="Width",
|
105 |
-
minimum=256,
|
106 |
-
maximum=MAX_IMAGE_SIZE,
|
107 |
-
step=32,
|
108 |
-
value=1024, # Replace with defaults that work for your model
|
109 |
-
)
|
110 |
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
with gr.Row():
|
120 |
-
|
121 |
-
label="
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
126 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
|
132 |
-
|
133 |
-
|
134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
135 |
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
140 |
inputs=[
|
141 |
-
prompt,
|
142 |
-
|
143 |
-
|
144 |
-
randomize_seed,
|
145 |
-
width,
|
146 |
-
height,
|
147 |
-
guidance_scale,
|
148 |
-
num_inference_steps,
|
149 |
],
|
150 |
-
outputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
)
|
152 |
|
153 |
if __name__ == "__main__":
|
154 |
-
demo.launch()
|
|
|
1 |
import gradio as gr
|
2 |
+
import torch
|
3 |
+
import base64
|
4 |
+
import io
|
5 |
+
from PIL import Image
|
6 |
+
from transformers import (
|
7 |
+
LlavaNextProcessor, LlavaNextForConditionalGeneration,
|
8 |
+
T5EncoderModel, T5Tokenizer
|
9 |
+
)
|
10 |
+
from transformers import (
|
11 |
+
AutoProcessor, AutoModelForCausalLM, GenerationConfig,
|
12 |
+
T5EncoderModel, T5Tokenizer
|
13 |
+
)
|
14 |
+
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler, FlowMatchHeunDiscreteScheduler, FluxPipeline
|
15 |
+
from tordi.diffusion.pipelines.onediffusion import OneDiffusionPipeline
|
16 |
+
from tordi.models.denoiser.nextdit import NextDiT
|
17 |
+
from tordi.dataset.utils import get_closest_ratio, ASPECT_RATIO_512
|
18 |
+
from typing import List, Optional
|
19 |
+
|
20 |
+
# Import additional libraries
|
21 |
+
import matplotlib
|
22 |
import numpy as np
|
23 |
+
import cv2
|
24 |
|
25 |
+
# Task-specific tokens
|
26 |
+
TASK2SPECIAL_TOKENS = {
|
27 |
+
"text2image": "[[text2image]]",
|
28 |
+
"deblurring": "[[deblurring]]",
|
29 |
+
"inpainting": "[[image_inpainting]]",
|
30 |
+
"canny": "[[canny2image]]",
|
31 |
+
"super_resolution": "[[super_resolution]]",
|
32 |
+
"depth2image": "[[depth2image]]",
|
33 |
+
"hed2image": "[[hed2img]]",
|
34 |
+
"pose2image": "[[pose2image]]",
|
35 |
+
"semanticmap2image": "[[semanticmap2image]]",
|
36 |
+
"boundingbox2image": "[[boundingbox2image]]",
|
37 |
+
"image_editing": "[[image_editing]]",
|
38 |
+
"faceid": "[[faceid]]",
|
39 |
+
"multiview": "[[multiview]]",
|
40 |
+
"subject_driven": "[[subject_driven]]"
|
41 |
+
}
|
42 |
+
NEGATIVE_PROMPT = "monochrome, greyscale, low-res, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry, artist name, poorly drawn, bad anatomy, wrong anatomy, extra limb, missing limb, floating limbs, disconnected limbs, mutation, mutated, ugly, disgusting, blurry, amputation"
|
43 |
+
|
44 |
+
|
45 |
+
class LlavaCaptionProcessor:
|
46 |
+
def __init__(self):
|
47 |
+
model_name = "llava-hf/llama3-llava-next-8b-hf"
|
48 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
49 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
50 |
+
self.processor = LlavaNextProcessor.from_pretrained(model_name)
|
51 |
+
self.model = LlavaNextForConditionalGeneration.from_pretrained(
|
52 |
+
model_name, torch_dtype=dtype, low_cpu_mem_usage=True
|
53 |
+
).to(device)
|
54 |
+
self.SPECIAL_TOKENS = "assistant\n\n\n"
|
55 |
+
|
56 |
+
def generate_response(self, image: Image.Image, msg: str) -> str:
|
57 |
+
conversation = [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": msg}]}]
|
58 |
+
with torch.no_grad():
|
59 |
+
prompt = self.processor.apply_chat_template(conversation, add_generation_prompt=True)
|
60 |
+
inputs = self.processor(prompt, image, return_tensors="pt").to(self.model.device)
|
61 |
+
output = self.model.generate(**inputs, max_new_tokens=250)
|
62 |
+
response = self.processor.decode(output[0], skip_special_tokens=True)
|
63 |
+
return response.split(msg)[-1].strip()[len(self.SPECIAL_TOKENS):]
|
64 |
|
65 |
+
def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
|
66 |
+
if msg is None:
|
67 |
+
msg = f"Describe the contents of the photo in 150 words or fewer."
|
68 |
+
try:
|
69 |
+
return [self.generate_response(img, msg) for img in images]
|
70 |
+
except Exception as e:
|
71 |
+
print(f"Error in process: {str(e)}")
|
72 |
+
raise
|
73 |
|
|
|
|
|
|
|
|
|
74 |
|
75 |
+
class MolmoCaptionProcessor:
|
76 |
+
def __init__(self):
|
77 |
+
pretrained_model_name = 'allenai/Molmo-7B-O-0924'
|
78 |
+
self.processor = AutoProcessor.from_pretrained(
|
79 |
+
pretrained_model_name,
|
80 |
+
trust_remote_code=True,
|
81 |
+
torch_dtype='auto',
|
82 |
+
device_map='auto'
|
83 |
+
)
|
84 |
+
self.model = AutoModelForCausalLM.from_pretrained(
|
85 |
+
pretrained_model_name,
|
86 |
+
trust_remote_code=True,
|
87 |
+
torch_dtype='auto',
|
88 |
+
device_map='auto'
|
89 |
+
)
|
90 |
|
91 |
+
def generate_response(self, image: Image.Image, msg: str) -> str:
|
92 |
+
inputs = self.processor.process(
|
93 |
+
images=[image],
|
94 |
+
text=msg
|
95 |
+
)
|
96 |
+
# Move inputs to the correct device and make a batch of size 1
|
97 |
+
inputs = {k: v.to(self.model.device).unsqueeze(0) for k, v in inputs.items()}
|
98 |
+
|
99 |
+
# Generate output
|
100 |
+
output = self.model.generate_from_batch(
|
101 |
+
inputs,
|
102 |
+
GenerationConfig(max_new_tokens=250, stop_strings="<|endoftext|>"),
|
103 |
+
tokenizer=self.processor.tokenizer
|
104 |
+
)
|
105 |
+
|
106 |
+
# Only get generated tokens and decode them to text
|
107 |
+
generated_tokens = output[0, inputs['input_ids'].size(1):]
|
108 |
+
return self.processor.tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
|
109 |
|
110 |
|
111 |
+
def process(self, images: List[Image.Image], msg: str = None) -> List[str]:
|
112 |
+
if msg is None:
|
113 |
+
msg = f"Describe the contents of the photo in 150 words or fewer."
|
114 |
+
try:
|
115 |
+
return [self.generate_response(img, msg) for img in images]
|
116 |
+
except Exception as e:
|
117 |
+
print(f"Error in process: {str(e)}")
|
118 |
+
raise
|
119 |
+
|
120 |
+
def initialize_models():
|
121 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
122 |
+
model = NextDiT.from_pretrained(
|
123 |
+
# "/data/input/duongl/finetuning_distributed_multiview_16x8_scalerays_dl3dv_dynamic_shift_softcap_editing/checkpoint-98000",
|
124 |
+
# "/data/input/duongl/data/input/duongl/finetuning_distributed_multiview_16x8_scalerays_dl3dv_dynamic_shift_softcap_trainingWithFluxScheduler/checkpoint-10000/", # "lehduong/OneDiffusion",
|
125 |
+
"lehduong/OneDiffusion",
|
126 |
+
subfolder="transformer",
|
127 |
+
torch_dtype=torch.float32,
|
128 |
+
).to(device)
|
129 |
+
vae = AutoencoderKL.from_pretrained("lehduong/OneDiffusion", subfolder="vae").to(device)
|
130 |
+
text_encoder = T5EncoderModel.from_pretrained("lehduong/OneDiffusion", subfolder="text_encoder", torch_dtype=torch.float16).to(device)
|
131 |
+
tokenizer = T5Tokenizer.from_pretrained("lehduong/OneDiffusion", subfolder="tokenizer")
|
132 |
+
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
|
133 |
+
# "stabilityai/stable-diffusion-3-medium-diffusers",
|
134 |
+
# "black-forest-labs/FLUX.1-dev",
|
135 |
+
"lehduong/OneDiffusion",
|
136 |
+
subfolder="scheduler"
|
137 |
+
)
|
138 |
+
# scheduler = FlowMatchEulerDiscreteScheduler(
|
139 |
+
# base_image_seq_len=256,
|
140 |
+
# base_shift=0.5,
|
141 |
+
# max_image_seq_len=4096,
|
142 |
+
# max_shift=1.16,
|
143 |
+
# num_train_timesteps=1000,
|
144 |
+
# shift=3.0,
|
145 |
+
# use_dynamic_shifting=True
|
146 |
+
# )
|
147 |
+
pipeline = OneDiffusionPipeline(
|
148 |
+
vae=vae, text_encoder=text_encoder, transformer=model, tokenizer=tokenizer, scheduler=scheduler
|
149 |
+
).to(torch.bfloat16)
|
150 |
+
molmo_caption_processor = MolmoCaptionProcessor() # LlavaCaptionProcessor()
|
151 |
+
return pipeline, molmo_caption_processor
|
152 |
+
|
153 |
+
def colorize_depth_maps(
|
154 |
+
depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None
|
155 |
):
|
156 |
+
"""
|
157 |
+
Colorize depth maps with reversed colors.
|
158 |
+
"""
|
159 |
+
assert len(depth_map.shape) >= 2, "Invalid dimension"
|
160 |
+
|
161 |
+
if isinstance(depth_map, torch.Tensor):
|
162 |
+
depth = depth_map.detach().squeeze().numpy()
|
163 |
+
elif isinstance(depth_map, np.ndarray):
|
164 |
+
depth = depth_map.copy().squeeze()
|
165 |
+
# reshape to [ (B,) H, W ]
|
166 |
+
if depth.ndim < 3:
|
167 |
+
depth = depth[np.newaxis, :, :]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
168 |
|
169 |
+
# Normalize depth values to [0, 1]
|
170 |
+
depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
|
171 |
+
# Invert the depth values to reverse the colors
|
172 |
+
depth = 1 - depth
|
173 |
|
174 |
+
# Use the colormap
|
175 |
+
cm = matplotlib.colormaps[cmap]
|
176 |
+
img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3] # values from 0 to 1
|
177 |
+
img_colored_np = np.rollaxis(img_colored_np, 3, 1)
|
|
|
|
|
|
|
|
|
178 |
|
179 |
+
if valid_mask is not None:
|
180 |
+
if isinstance(depth_map, torch.Tensor):
|
181 |
+
valid_mask = valid_mask.detach().numpy()
|
182 |
+
valid_mask = valid_mask.squeeze() # [H, W] or [B, H, W]
|
183 |
+
if valid_mask.ndim < 3:
|
184 |
+
valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
|
185 |
+
else:
|
186 |
+
valid_mask = valid_mask[:, np.newaxis, :, :]
|
187 |
+
valid_mask = np.repeat(valid_mask, 3, axis=1)
|
188 |
+
img_colored_np[~valid_mask] = 0
|
189 |
|
190 |
+
if isinstance(depth_map, torch.Tensor):
|
191 |
+
img_colored = torch.from_numpy(img_colored_np).float()
|
192 |
+
elif isinstance(depth_map, np.ndarray):
|
193 |
+
img_colored = img_colored_np
|
194 |
|
195 |
+
return img_colored
|
196 |
+
|
197 |
+
|
198 |
+
def format_prompt(task_type: str, captions: List[str]) -> str:
|
199 |
+
if not captions:
|
200 |
+
return ""
|
201 |
+
if task_type == "faceid":
|
202 |
+
img_prompts = [f"[[img{i}]] {caption}" for i, caption in enumerate(captions, start=1)]
|
203 |
+
return f"[[faceid]] [[img0]] insert/your/caption/here {' '.join(img_prompts)}"
|
204 |
+
elif task_type == "image_editing":
|
205 |
+
target_caption = captions[0] if len(captions) >= 2 else "Insert target caption here"
|
206 |
+
source_caption = captions[1] if len(captions) >= 2 else captions[0]
|
207 |
+
return f"[[image_editing]] [[target_caption]] {target_caption} [[source_caption]] {source_caption}"
|
208 |
+
elif task_type == "semanticmap2image":
|
209 |
+
return f"[[semanticmap2image]] <#00ffff Cyan mask: insert/concept/to/segment/here> {captions[0]}"
|
210 |
+
elif task_type == "boundingbox2image":
|
211 |
+
return f"[[boundingbox2image]] <#00ffff Cyan boundingbox: insert/concept/to/segment/here> {captions[0]}"
|
212 |
+
elif task_type == "multiview":
|
213 |
+
# img_prompts = [f"[[img{i}]] {caption}" for i, caption in enumerate(captions)]
|
214 |
+
img_prompts = captions[0]
|
215 |
+
return f"[[multiview]] {img_prompts}"
|
216 |
+
elif task_type == "subject_driven":
|
217 |
+
return f"[[subject_driven]] <item: insert/item/here> [[img0]] insert/your/target/caption/here [[img1]] {captions[0]}"
|
218 |
+
else:
|
219 |
+
return f"{TASK2SPECIAL_TOKENS[task_type]} {captions[0]}"
|
220 |
+
|
221 |
+
def update_prompt(images: List[Image.Image], task_type: str, custom_msg: str = None):
|
222 |
+
if not images:
|
223 |
+
return format_prompt(task_type, []), "Please upload at least one image!"
|
224 |
+
try:
|
225 |
+
captions = molmo_processor.process(images, custom_msg)
|
226 |
+
if not captions:
|
227 |
+
return "", "No valid images found!"
|
228 |
+
prompt = format_prompt(task_type, captions)
|
229 |
+
return prompt, f"Generated {len(captions)} captions successfully!"
|
230 |
+
except Exception as e:
|
231 |
+
return "", f"Error generating captions: {str(e)}"
|
232 |
+
|
233 |
+
def generate_image(images: List[Image.Image], prompt: str, negative_prompt: str, num_inference_steps: int, guidance_scale: float, pag_guidance_scale: float,
|
234 |
+
denoise_mask: List[str], task_type: str, azimuth: str, elevation: str, distance: str, focal_length: float,
|
235 |
+
height: int = 1024, width: int = 1024, scale_factor: float = 1.0, scale_watershed: float = 1.0,
|
236 |
+
noise_scale: float = None, progress=gr.Progress()):
|
237 |
+
try:
|
238 |
+
img2img_kwargs = {
|
239 |
+
'prompt': prompt,
|
240 |
+
'negative_prompt': negative_prompt,
|
241 |
+
'num_inference_steps': num_inference_steps,
|
242 |
+
'guidance_scale': guidance_scale,
|
243 |
+
'height': height,
|
244 |
+
'width': width,
|
245 |
+
'forward_kwargs': {
|
246 |
+
'scale_factor': scale_factor,
|
247 |
+
'scale_watershed': scale_watershed
|
248 |
+
},
|
249 |
+
'noise_scale': noise_scale # Added noise_scale here
|
250 |
+
}
|
251 |
+
|
252 |
+
if task_type == 'multiview':
|
253 |
+
# Parse azimuth, elevation, and distance into lists, allowing 'None' values
|
254 |
+
azimuths = [float(a.strip()) if a.strip().lower() != 'none' else None for a in azimuth.split(',')] if azimuth else []
|
255 |
+
elevations = [float(e.strip()) if e.strip().lower() != 'none' else None for e in elevation.split(',')] if elevation else []
|
256 |
+
distances = [float(d.strip()) if d.strip().lower() != 'none' else None for d in distance.split(',')] if distance else []
|
257 |
+
|
258 |
+
num_views = max(len(images), len(azimuths), len(elevations), len(distances))
|
259 |
+
if num_views == 0:
|
260 |
+
return None, "At least one image or camera parameter must be provided."
|
261 |
+
|
262 |
+
total_components = []
|
263 |
+
for i in range(num_views):
|
264 |
+
total_components.append(f"image_{i}")
|
265 |
+
total_components.append(f"camera_pose_{i}")
|
266 |
+
|
267 |
+
denoise_mask_int = [1 if comp in denoise_mask else 0 for comp in total_components]
|
268 |
+
|
269 |
+
if len(denoise_mask_int) != len(total_components):
|
270 |
+
return None, f"Denoise mask length mismatch: expected {len(total_components)} components."
|
271 |
+
|
272 |
+
# Pad the input lists to num_views length
|
273 |
+
images_padded = images + [] * (num_views - len(images)) # Do not add None
|
274 |
+
azimuths_padded = azimuths + [None] * (num_views - len(azimuths))
|
275 |
+
elevations_padded = elevations + [None] * (num_views - len(elevations))
|
276 |
+
distances_padded = distances + [None] * (num_views - len(distances))
|
277 |
+
|
278 |
+
# Prepare values
|
279 |
+
img2img_kwargs.update({
|
280 |
+
'image': images_padded,
|
281 |
+
'multiview_azimuths': azimuths_padded,
|
282 |
+
'multiview_elevations': elevations_padded,
|
283 |
+
'multiview_distances': distances_padded,
|
284 |
+
'multiview_focal_length': focal_length, # Pass focal_length here
|
285 |
+
'is_multiview': True,
|
286 |
+
'denoise_mask': denoise_mask_int,
|
287 |
+
# 'predict_camera_poses': True,
|
288 |
+
})
|
289 |
+
else:
|
290 |
+
total_components = ["image_0"] + [f"image_{i+1}" for i in range(len(images))]
|
291 |
+
denoise_mask_int = [1 if comp in denoise_mask else 0 for comp in total_components]
|
292 |
+
if len(denoise_mask_int) != len(total_components):
|
293 |
+
return None, f"Denoise mask length mismatch: expected {len(total_components)} components."
|
294 |
|
295 |
+
img2img_kwargs.update({
|
296 |
+
'image': images,
|
297 |
+
'denoise_mask': denoise_mask_int
|
298 |
+
})
|
299 |
+
|
300 |
+
progress(0, desc="Generating image...")
|
301 |
+
if task_type == 'text2image':
|
302 |
+
output = pipeline(
|
303 |
+
prompt=prompt,
|
304 |
+
negative_prompt=negative_prompt,
|
305 |
+
num_inference_steps=num_inference_steps,
|
306 |
+
guidance_scale=guidance_scale,
|
307 |
+
pag_guidance_scale=pag_guidance_scale,
|
308 |
+
height=height,
|
309 |
+
width=width,
|
310 |
+
scale_factor=scale_factor,
|
311 |
+
scale_watershed=scale_watershed,
|
312 |
+
noise_scale=noise_scale # Added noise_scale here
|
313 |
)
|
314 |
+
else:
|
315 |
+
output = pipeline.img2img(**img2img_kwargs)
|
316 |
+
progress(1, desc="Done!")
|
317 |
|
318 |
+
# Process the output images if task is 'depth2image' and predicting depth
|
319 |
+
if task_type == 'depth2image' and denoise_mask_int[-1] == 1:
|
320 |
+
processed_images = []
|
321 |
+
for img in output.images:
|
322 |
+
depth_map = np.array(img.convert('L')) # Convert to grayscale numpy array
|
323 |
+
min_depth = depth_map.min()
|
324 |
+
max_depth = depth_map.max()
|
325 |
+
colorized = colorize_depth_maps(depth_map, min_depth, max_depth)[0]
|
326 |
+
colorized = np.transpose(colorized, (1, 2, 0))
|
327 |
+
colorized = (colorized * 255).astype(np.uint8)
|
328 |
+
img_colorized = Image.fromarray(colorized)
|
329 |
+
processed_images.append(img_colorized)
|
330 |
+
output_images = processed_images + output.images
|
331 |
+
elif task_type in ['boundingbox2image', 'semanticmap2image'] and denoise_mask_int == [0,1] and images:
|
332 |
+
# Interpolate between input and output images
|
333 |
+
processed_images = []
|
334 |
+
for input_img, output_img in zip(images, output.images):
|
335 |
+
input_img_resized = input_img.resize(output_img.size)
|
336 |
+
blended_img = Image.blend(input_img_resized, output_img, alpha=0.5)
|
337 |
+
processed_images.append(blended_img)
|
338 |
+
output_images = processed_images + output.images
|
339 |
+
else:
|
340 |
+
output_images = output.images
|
341 |
|
342 |
+
return output_images, "Generation completed successfully!"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
|
344 |
+
except Exception as e:
|
345 |
+
return None, f"Error during generation: {str(e)}"
|
346 |
+
|
347 |
+
def update_denoise_checkboxes(images_state: List[Image.Image], task_type: str, azimuth: str, elevation: str, distance: str):
|
348 |
+
if task_type == 'multiview':
|
349 |
+
azimuths = [a.strip() for a in azimuth.split(',')] if azimuth else []
|
350 |
+
elevations = [e.strip() for e in elevation.split(',')] if elevation else []
|
351 |
+
distances = [d.strip() for d in distance.split(',')] if distance else []
|
352 |
+
images_len = len(images_state)
|
353 |
+
|
354 |
+
num_views = max(images_len, len(azimuths), len(elevations), len(distances))
|
355 |
+
if num_views == 0:
|
356 |
+
return gr.update(choices=[], value=[]), "Please provide at least one image or camera parameter."
|
357 |
+
|
358 |
+
# Pad lists to the same length
|
359 |
+
azimuths += ['None'] * (num_views - len(azimuths))
|
360 |
+
elevations += ['None'] * (num_views - len(elevations))
|
361 |
+
distances += ['None'] * (num_views - len(distances))
|
362 |
+
# Do not add None to images_state
|
363 |
+
|
364 |
+
labels = []
|
365 |
+
values = []
|
366 |
+
for i in range(num_views):
|
367 |
+
labels.append(f"image_{i}")
|
368 |
+
labels.append(f"camera_pose_{i}")
|
369 |
+
|
370 |
+
# Default behavior: condition on provided inputs, generate missing ones
|
371 |
+
if i >= images_len:
|
372 |
+
values.append(f"image_{i}")
|
373 |
+
if azimuths[i].lower() == 'none' or elevations[i].lower() == 'none' or distances[i].lower() == 'none':
|
374 |
+
values.append(f"camera_pose_{i}")
|
375 |
+
|
376 |
+
return gr.update(choices=labels, value=values)
|
377 |
+
else:
|
378 |
+
labels = ["image_0"] + [f"image_{i+1}" for i in range(len(images_state))]
|
379 |
+
values = ["image_0"]
|
380 |
+
return gr.update(choices=labels, value=values)
|
381 |
+
|
382 |
+
def apply_mask(images_state):
|
383 |
+
if len(images_state) < 2:
|
384 |
+
return None, "Please upload at least two images: first as the base image, second as the mask."
|
385 |
+
base_img = images_state[0]
|
386 |
+
mask_img = images_state[1]
|
387 |
+
|
388 |
+
# Convert images to arrays
|
389 |
+
base_arr = np.array(base_img)
|
390 |
+
mask_arr = np.array(mask_img)
|
391 |
+
|
392 |
+
# Convert mask to grayscale
|
393 |
+
if mask_arr.ndim == 3:
|
394 |
+
gray_mask = cv2.cvtColor(mask_arr, cv2.COLOR_RGB2GRAY)
|
395 |
+
else:
|
396 |
+
gray_mask = mask_arr
|
397 |
+
|
398 |
+
# Create a binary mask where non-black pixels are True
|
399 |
+
binary_mask = gray_mask > 10
|
400 |
+
|
401 |
+
# Define the gray color
|
402 |
+
gray_color = np.array([128, 128, 128], dtype=np.uint8)
|
403 |
+
|
404 |
+
# Apply gray color where mask is True
|
405 |
+
masked_arr = base_arr.copy()
|
406 |
+
masked_arr[binary_mask] = gray_color
|
407 |
+
|
408 |
+
masked_img = Image.fromarray(masked_arr)
|
409 |
+
return [masked_img], "Mask applied successfully!"
|
410 |
+
|
411 |
+
def process_images_for_task_type(images_state: List[Image.Image], task_type: str):
|
412 |
+
# No changes needed here since we are processing the output images
|
413 |
+
return images_state, images_state
|
414 |
+
|
415 |
+
# Initialize models
|
416 |
+
pipeline, molmo_processor = initialize_models()
|
417 |
+
|
418 |
+
with gr.Blocks(title="OneDiffusion Demo") as demo:
|
419 |
+
gr.Markdown("""
|
420 |
+
# OneDiffusion Demo
|
421 |
+
|
422 |
+
**Welcome to the OneDiffusion Demo!**
|
423 |
+
|
424 |
+
This application allows you to generate images based on your input prompts for various tasks. Here's how to use it:
|
425 |
+
|
426 |
+
1. **Select Task Type**: Choose the type of task you want to perform from the "Task Type" dropdown menu.
|
427 |
+
|
428 |
+
2. **Upload Images**: Drag and drop images directly onto the upload area, or click to select files from your device.
|
429 |
|
430 |
+
3. **Generate Captions**: **If you upload any images**, Click the "Generate Captions with Molmo" button to generate descriptive captions for your uploaded images (depend on the task). You can enter a custom message in the "Custom Message for Molmo" textbox e.g., "caption in 50 words" instead of 100 words.
|
431 |
+
|
432 |
+
4. **Configure Generation Settings**: Expand the "Advanced Configuration" section to adjust parameters like the number of inference steps, guidance scale, image size, and more.
|
433 |
+
|
434 |
+
5. **Generate Images**: After setting your preferences, click the "Generate Image" button. The generated images will appear in the "Generated Images" gallery.
|
435 |
+
|
436 |
+
6. **Manage Images**: Use the "Delete Selected Images" or "Delete All Images" buttons to remove unwanted images from the gallery.
|
437 |
+
|
438 |
+
**Notes**:
|
439 |
+
- For text-to-image:
|
440 |
+
+ simply enter your prompt in this format "[[text2image]] your/prompt/here" and press the "Generate Image" button.
|
441 |
+
|
442 |
+
- For boundingbox2image/semantic2image/inpainting etc tasks:
|
443 |
+
+ To perform condition-to-image such as semantic map to image, follow above steps
|
444 |
+
+ For image-to-condition e.g., image to depth, change the denoise_mask checkbox before generating images. You must UNCHECK image_0 box and CHECK image_1 box.
|
445 |
+
|
446 |
+
- For FaceID tasks:
|
447 |
+
+ Use 3 or 4 images if single input image does not give satisfactory results.
|
448 |
+
+ All images will be resized and center cropped to the input height and width. You should choose height and width so that faces in input images won't be cropped.
|
449 |
+
+ Model works best with close-up portrait (input and output) images.
|
450 |
+
+ If the model does not conform your text prompt, try using shorter caption for source image(s).
|
451 |
+
+ If you have non-human subjects and does not get satisfactory results, try "copying" part of caption of source images where it describes the properties of the subject e.g., a monster with red eyes, sharp teeth, etc.
|
452 |
+
|
453 |
+
- For Multiview generation:
|
454 |
+
+ Only support square images (ideally in 512x512 resolution).
|
455 |
+
+ Ensure the number of elevations, azimuths, and distances are equal.
|
456 |
+
+ The model generally works well for 2-5 views (include both input and generated images). Since the model is trained with 3 views on 512x512 resolution, you might try scale_factor of [1.1; 1.5] and scale_watershed of [100; 400] for better extrapolation.
|
457 |
+
+ For better results:
|
458 |
+
1) try increasing num_inference_steps to 75-100.
|
459 |
+
2) avoid aggressively changes in target camera poses, for example to generate novel views at azimuth of 180, (simultaneously) generate 4 views with azimuth of 45, 90, 135, 180.
|
460 |
+
|
461 |
+
Enjoy creating images with OneDiffusion!
|
462 |
+
""")
|
463 |
+
|
464 |
+
with gr.Row():
|
465 |
+
with gr.Column():
|
466 |
+
images_state = gr.State([])
|
467 |
+
selected_indices_state = gr.State([])
|
468 |
+
|
469 |
with gr.Row():
|
470 |
+
gallery = gr.Gallery(
|
471 |
+
label="Input Images",
|
472 |
+
show_label=True,
|
473 |
+
columns=2,
|
474 |
+
rows=2,
|
475 |
+
height="auto",
|
476 |
+
object_fit="contain"
|
477 |
)
|
478 |
+
|
479 |
+
# In the UI section, update the file_output component:
|
480 |
+
file_output = gr.File(
|
481 |
+
file_count="multiple",
|
482 |
+
file_types=["image"],
|
483 |
+
label="Drag and drop images here or click to upload",
|
484 |
+
height=100,
|
485 |
+
scale=2,
|
486 |
+
type="filepath" # Add this parameter
|
487 |
+
)
|
488 |
+
|
489 |
+
with gr.Row():
|
490 |
+
delete_button = gr.Button("Delete Selected Images")
|
491 |
+
delete_all_button = gr.Button("Delete All Images")
|
492 |
+
|
493 |
+
task_type = gr.Dropdown(
|
494 |
+
choices=list(TASK2SPECIAL_TOKENS.keys()),
|
495 |
+
value="text2image",
|
496 |
+
label="Task Type"
|
497 |
+
)
|
498 |
+
|
499 |
+
molmo_message = gr.Textbox(
|
500 |
+
lines=2,
|
501 |
+
value="Describe the contents of the photo in 100 words.",
|
502 |
+
label="Custom message for Molmo captioner"
|
503 |
+
)
|
504 |
+
|
505 |
+
auto_caption_btn = gr.Button("Generate Captions with Molmo")
|
506 |
|
507 |
+
with gr.Column():
|
508 |
+
prompt = gr.Textbox(
|
509 |
+
lines=3,
|
510 |
+
placeholder="Enter your prompt here or use auto-caption...",
|
511 |
+
label="Prompt"
|
512 |
+
)
|
513 |
+
negative_prompt = gr.Textbox(
|
514 |
+
lines=3,
|
515 |
+
value=NEGATIVE_PROMPT,
|
516 |
+
placeholder="Enter negative prompt here...",
|
517 |
+
label="Negative Prompt"
|
518 |
+
)
|
519 |
+
caption_status = gr.Textbox(label="Caption Status")
|
520 |
+
|
521 |
+
num_steps = gr.Slider(
|
522 |
+
minimum=1,
|
523 |
+
maximum=200,
|
524 |
+
value=30,
|
525 |
+
step=1,
|
526 |
+
label="Number of Inference Steps"
|
527 |
+
)
|
528 |
+
guidance_scale = gr.Slider(
|
529 |
+
minimum=0.1,
|
530 |
+
maximum=10.0,
|
531 |
+
value=4,
|
532 |
+
step=0.1,
|
533 |
+
label="Guidance Scale"
|
534 |
+
)
|
535 |
+
pag_guidance_scale = gr.Slider(
|
536 |
+
minimum=0.1,
|
537 |
+
maximum=10.0,
|
538 |
+
value=1,
|
539 |
+
step=0.1,
|
540 |
+
label="PAG guidance Scale"
|
541 |
+
)
|
542 |
+
height = gr.Number(value=1024, label="Height")
|
543 |
+
width = gr.Number(value=1024, label="Width")
|
544 |
+
|
545 |
+
with gr.Accordion("Advanced Configuration", open=False):
|
546 |
+
with gr.Row():
|
547 |
+
denoise_mask_checkbox = gr.CheckboxGroup(
|
548 |
+
label="Denoise Mask",
|
549 |
+
choices=["image_0"],
|
550 |
+
value=["image_0"]
|
551 |
+
)
|
552 |
+
azimuth = gr.Textbox(
|
553 |
+
value="0",
|
554 |
+
label="Azimuths (degrees, comma-separated, 'None' for missing)"
|
555 |
+
)
|
556 |
+
elevation = gr.Textbox(
|
557 |
+
value="0",
|
558 |
+
label="Elevations (degrees, comma-separated, 'None' for missing)"
|
559 |
+
)
|
560 |
+
distance = gr.Textbox(
|
561 |
+
value="1.5",
|
562 |
+
label="Distances (comma-separated, 'None' for missing)"
|
563 |
+
)
|
564 |
+
focal_length = gr.Number(
|
565 |
+
value=1.3887,
|
566 |
+
label="Focal Length of camera for multiview generation"
|
567 |
+
)
|
568 |
+
scale_factor = gr.Number(value=1.0, label="Scale Factor")
|
569 |
+
scale_watershed = gr.Number(value=1.0, label="Scale Watershed")
|
570 |
+
noise_scale = gr.Number(value=1.0, label="Noise Scale") # Added noise_scale input
|
571 |
+
|
572 |
+
output_images = gr.Gallery(
|
573 |
+
label="Generated Images",
|
574 |
+
show_label=True,
|
575 |
+
columns=4,
|
576 |
+
rows=2,
|
577 |
+
height="auto",
|
578 |
+
object_fit="contain"
|
579 |
+
)
|
580 |
+
|
581 |
+
with gr.Column():
|
582 |
+
generate_btn = gr.Button("Generate Image")
|
583 |
+
apply_mask_btn = gr.Button("Apply Mask")
|
584 |
+
|
585 |
+
status = gr.Textbox(label="Generation Status")
|
586 |
+
|
587 |
+
# Event Handlers
|
588 |
+
def update_gallery(files, images_state):
|
589 |
+
if not files:
|
590 |
+
return images_state, images_state
|
591 |
+
|
592 |
+
new_images = []
|
593 |
+
for file in files:
|
594 |
+
try:
|
595 |
+
# Handle both file paths and file objects
|
596 |
+
if isinstance(file, dict): # For drag and drop files
|
597 |
+
file = file['path']
|
598 |
+
elif hasattr(file, 'name'): # For uploaded files
|
599 |
+
file = file.name
|
600 |
+
|
601 |
+
img = Image.open(file).convert('RGB')
|
602 |
+
new_images.append(img)
|
603 |
+
except Exception as e:
|
604 |
+
print(f"Error loading image: {str(e)}")
|
605 |
+
continue
|
606 |
+
|
607 |
+
images_state.extend(new_images)
|
608 |
+
return images_state, images_state
|
609 |
|
610 |
+
def on_image_select(evt: gr.SelectData, selected_indices_state):
|
611 |
+
selected_indices = selected_indices_state or []
|
612 |
+
index = evt.index
|
613 |
+
if index in selected_indices:
|
614 |
+
selected_indices.remove(index)
|
615 |
+
else:
|
616 |
+
selected_indices.append(index)
|
617 |
+
return selected_indices
|
618 |
+
|
619 |
+
def delete_images(selected_indices, images_state):
|
620 |
+
updated_images = [img for i, img in enumerate(images_state) if i not in selected_indices]
|
621 |
+
selected_indices_state = []
|
622 |
+
return updated_images, updated_images, selected_indices_state
|
623 |
+
|
624 |
+
def delete_all_images(images_state):
|
625 |
+
updated_images = []
|
626 |
+
selected_indices_state = []
|
627 |
+
return updated_images, updated_images, selected_indices_state
|
628 |
+
|
629 |
+
def update_height_width(images_state):
|
630 |
+
if images_state:
|
631 |
+
closest_ar = get_closest_ratio(
|
632 |
+
height=images_state[0].size[1],
|
633 |
+
width=images_state[0].size[0],
|
634 |
+
ratios=ASPECT_RATIO_512
|
635 |
+
)
|
636 |
+
height_val, width_val = int(closest_ar[0][0]), int(closest_ar[0][1])
|
637 |
+
else:
|
638 |
+
height_val, width_val = 1024, 1024 # Default values
|
639 |
+
return gr.update(value=height_val), gr.update(value=width_val)
|
640 |
+
|
641 |
+
# Connect events
|
642 |
+
file_output.change(
|
643 |
+
fn=update_gallery,
|
644 |
+
inputs=[file_output, images_state],
|
645 |
+
outputs=[images_state, gallery]
|
646 |
+
).then(
|
647 |
+
fn=update_height_width,
|
648 |
+
inputs=[images_state],
|
649 |
+
outputs=[height, width]
|
650 |
+
).then(
|
651 |
+
fn=update_denoise_checkboxes,
|
652 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
653 |
+
outputs=[denoise_mask_checkbox]
|
654 |
+
)
|
655 |
+
|
656 |
+
gallery.select(
|
657 |
+
fn=on_image_select,
|
658 |
+
inputs=[selected_indices_state],
|
659 |
+
outputs=[selected_indices_state]
|
660 |
+
)
|
661 |
+
|
662 |
+
delete_button.click(
|
663 |
+
fn=delete_images,
|
664 |
+
inputs=[selected_indices_state, images_state],
|
665 |
+
outputs=[images_state, gallery, selected_indices_state]
|
666 |
+
).then(
|
667 |
+
fn=update_denoise_checkboxes,
|
668 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
669 |
+
outputs=[denoise_mask_checkbox]
|
670 |
+
)
|
671 |
+
|
672 |
+
delete_all_button.click(
|
673 |
+
fn=delete_all_images,
|
674 |
+
inputs=[images_state],
|
675 |
+
outputs=[images_state, gallery, selected_indices_state]
|
676 |
+
).then(
|
677 |
+
fn=update_denoise_checkboxes,
|
678 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
679 |
+
outputs=[denoise_mask_checkbox]
|
680 |
+
)
|
681 |
+
|
682 |
+
task_type.change(
|
683 |
+
fn=update_denoise_checkboxes,
|
684 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
685 |
+
outputs=[denoise_mask_checkbox]
|
686 |
+
)
|
687 |
+
|
688 |
+
azimuth.change(
|
689 |
+
fn=update_denoise_checkboxes,
|
690 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
691 |
+
outputs=[denoise_mask_checkbox]
|
692 |
+
)
|
693 |
+
|
694 |
+
elevation.change(
|
695 |
+
fn=update_denoise_checkboxes,
|
696 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
697 |
+
outputs=[denoise_mask_checkbox]
|
698 |
+
)
|
699 |
+
|
700 |
+
distance.change(
|
701 |
+
fn=update_denoise_checkboxes,
|
702 |
+
inputs=[images_state, task_type, azimuth, elevation, distance],
|
703 |
+
outputs=[denoise_mask_checkbox]
|
704 |
+
)
|
705 |
+
|
706 |
+
generate_btn.click(
|
707 |
+
fn=generate_image,
|
708 |
inputs=[
|
709 |
+
images_state, prompt, negative_prompt, num_steps, guidance_scale, pag_guidance_scale,
|
710 |
+
denoise_mask_checkbox, task_type, azimuth, elevation, distance,
|
711 |
+
focal_length, height, width, scale_factor, scale_watershed, noise_scale # Added noise_scale here
|
|
|
|
|
|
|
|
|
|
|
712 |
],
|
713 |
+
outputs=[output_images, status]
|
714 |
+
)
|
715 |
+
|
716 |
+
auto_caption_btn.click(
|
717 |
+
fn=update_prompt,
|
718 |
+
inputs=[images_state, task_type, molmo_message],
|
719 |
+
outputs=[prompt, caption_status]
|
720 |
+
)
|
721 |
+
|
722 |
+
apply_mask_btn.click(
|
723 |
+
fn=apply_mask,
|
724 |
+
inputs=[images_state],
|
725 |
+
outputs=[output_images, status]
|
726 |
)
|
727 |
|
728 |
if __name__ == "__main__":
|
729 |
+
demo.launch(share=True)
|