gokaygokay committed on
Commit 04eb2f6
1 Parent(s): 042efd8

Update app.py

Replace the plain AuraFlow text-to-image demo with a captioning and enhancement workflow: an input image can be captioned by a VLM (Florence-2 or a PaliGemma-based long captioner), the resulting (or user-supplied) prompt can optionally be rewritten by a Lamini prompt enhancer, and AuraFlow then generates the image.

Files changed (1)
  1. app.py +173 -130
app.py CHANGED
@@ -1,162 +1,205 @@
 import gradio as gr
-import numpy as np
-import random
-from diffusers import AuraFlowPipeline
 import torch
-import spaces
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
-#torch.set_float32_matmul_precision("high")
-
-#torch._inductor.config.conv_1x1_as_mm = True
-#torch._inductor.config.coordinate_descent_tuning = True
-#torch._inductor.config.epilogue_fusion = False
-#torch._inductor.config.coordinate_descent_check_all_directions = True
-
 pipe = AuraFlowPipeline.from_pretrained(
-    "fal/AuraFlow",
     torch_dtype=torch.float16
-).to("cuda")
 
-#pipe.transformer.to(memory_format=torch.channels_last)
-#pipe.vae.to(memory_format=torch.channels_last)
 
-#pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
-#pipe.vae.decode = torch.compile(pipe.vae.decode, mode="max-autotune", fullgraph=True)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
-@spaces.GPU
-def infer(prompt, negative_prompt="", seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=5.0, num_inference_steps=28, progress=gr.Progress(track_tqdm=True)):
 
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
-
-    generator = torch.Generator().manual_seed(seed)
 
     image = pipe(
-        prompt = prompt,
-        negative_prompt = negative_prompt,
-        width=width,
         height=height,
-        guidance_scale = guidance_scale,
-        num_inference_steps = num_inference_steps,
-        generator = generator
-    ).images[0]
 
     return image, seed
 
-examples = [
-    "A photo of a lavender cat",
-    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
-    "An astronaut riding a green horse",
-    "A delicious ceviche cheesecake slice",
-]
-
-css="""
-#col-container {
-    margin: 0 auto;
-    max-width: 520px;
 }
 """
 
-if torch.cuda.is_available():
-    power_device = "GPU"
-else:
-    power_device = "CPU"
 
-with gr.Blocks(css=css) as demo:
 
-    with gr.Column(elem_id="col-container"):
-        gr.Markdown(f"""
-        # AuraFlow 0.1
-        Demo of the [AuraFlow 0.1](https://huggingface.co/fal/AuraFlow) 6.8B parameters open source diffusion transformer model
-        [[blog](https://blog.fal.ai/auraflow/)] [[model](https://huggingface.co/fal/AuraFlow)] [[fal](https://fal.ai/models/fal-ai/aura-flow)]
-        """)
-
-        with gr.Row():
 
-            prompt = gr.Text(
-                label="Prompt",
-                show_label=False,
-                max_lines=1,
-                placeholder="Enter your prompt",
-                container=False,
-            )
 
-            run_button = gr.Button("Run", scale=0)
 
-        result = gr.Image(label="Result", show_label=False)
-
-        with gr.Accordion("Advanced Settings", open=False):
-
-            negative_prompt = gr.Text(
-                label="Negative prompt",
-                max_lines=1,
-                placeholder="Enter a negative prompt",
-            )
-
-            seed = gr.Slider(
-                label="Seed",
-                minimum=0,
-                maximum=MAX_SEED,
-                step=1,
-                value=0,
-            )
-
-            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
-
-            with gr.Row():
-
-                width = gr.Slider(
-                    label="Width",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,
-                )
-
-                height = gr.Slider(
-                    label="Height",
-                    minimum=256,
-                    maximum=MAX_IMAGE_SIZE,
-                    step=32,
-                    value=1024,
-                )
-
-            with gr.Row():
-
-                guidance_scale = gr.Slider(
-                    label="Guidance scale",
-                    minimum=0.0,
-                    maximum=10.0,
-                    step=0.1,
-                    value=5.0,
-                )
-
-                num_inference_steps = gr.Slider(
-                    label="Number of inference steps",
-                    minimum=1,
-                    maximum=50,
-                    step=1,
-                    value=28,
-                )
-
-        gr.Examples(
-            examples = examples,
-            fn = infer,
-            inputs = [prompt],
-            outputs = [result, seed],
-            cache_examples="lazy"
-        )
-
-    gr.on(
-        triggers=[run_button.click, prompt.submit, negative_prompt.submit],
-        fn = infer,
-        inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
-        outputs = [result, seed]
     )
 
-demo.queue().launch()
 
+import spaces
 import gradio as gr
 import torch
+from PIL import Image
+from transformers import PaliGemmaForConditionalGeneration, PaliGemmaProcessor, pipeline
+from transformers import AutoProcessor, AutoModelForCausalLM
+from diffusers import AuraFlowPipeline
+import re
+import random
+import numpy as np
 
+# Initialize models
 device = "cuda" if torch.cuda.is_available() else "cpu"
+dtype = torch.float16
 
+# AuraFlow model
 pipe = AuraFlowPipeline.from_pretrained(
+    "fal/AuraFlow",
     torch_dtype=torch.float16
+).to(device)
 
+# VLM Captioner
+vlm_model = PaliGemmaForConditionalGeneration.from_pretrained("gokaygokay/sd3-long-captioner-v2").to(device).eval()
+vlm_processor = PaliGemmaProcessor.from_pretrained("gokaygokay/sd3-long-captioner-v2")
 
+# Initialize Florence model
+florence_model = AutoModelForCausalLM.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True).to(device).eval()
+florence_processor = AutoProcessor.from_pretrained('microsoft/Florence-2-base', trust_remote_code=True)
+
+# Prompt Enhancer
+enhancer_medium = pipeline("summarization", model="gokaygokay/Lamini-fal-prompt-enchance", device=device)
+enhancer_long = pipeline("summarization", model="gokaygokay/Lamini-Prompt-Enchance-Long", device=device)
 
 MAX_SEED = np.iinfo(np.int32).max
 MAX_IMAGE_SIZE = 1024
 
+# Florence caption function
+def florence_caption(image):
+    # Convert image to PIL if it's not already
+    if not isinstance(image, Image.Image):
+        image = Image.fromarray(image)
+
+    inputs = florence_processor(text="<MORE_DETAILED_CAPTION>", images=image, return_tensors="pt").to(device)
+    generated_ids = florence_model.generate(
+        input_ids=inputs["input_ids"],
+        pixel_values=inputs["pixel_values"],
+        max_new_tokens=1024,
+        early_stopping=False,
+        do_sample=False,
+        num_beams=3,
+    )
+    generated_text = florence_processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+    parsed_answer = florence_processor.post_process_generation(
+        generated_text,
+        task="<MORE_DETAILED_CAPTION>",
+        image_size=(image.width, image.height)
+    )
+    return parsed_answer["<MORE_DETAILED_CAPTION>"]
+
+# VLM Captioner function
+def create_captions_rich(image):
+    prompt = "caption en"
+    model_inputs = vlm_processor(text=prompt, images=image, return_tensors="pt").to(device)
+    input_len = model_inputs["input_ids"].shape[-1]
+
+    with torch.inference_mode():
+        generation = vlm_model.generate(**model_inputs, repetition_penalty=1.10, max_new_tokens=256, do_sample=False)
+        generation = generation[0][input_len:]
+        decoded = vlm_processor.decode(generation, skip_special_tokens=True)
+
+    return modify_caption(decoded)
+
+# Helper function for caption modification
+def modify_caption(caption: str) -> str:
+    prefix_substrings = [
+        ('captured from ', ''),
+        ('captured at ', '')
+    ]
+    pattern = '|'.join([re.escape(opening) for opening, _ in prefix_substrings])
+    replacers = {opening: replacer for opening, replacer in prefix_substrings}
+
+    def replace_fn(match):
+        return replacers[match.group(0)]
+
+    return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
+
+# Prompt Enhancer function
+def enhance_prompt(input_prompt, model_choice):
+    if model_choice == "Medium":
+        result = enhancer_medium("Enhance the description: " + input_prompt)
+        enhanced_text = result[0]['summary_text']
+
+    else: # Long
+        result = enhancer_long("Enhance the description: " + input_prompt)
+        enhanced_text = result[0]['summary_text']
+
+    return enhanced_text
 
+def generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
+
+    generator = torch.Generator(device=device).manual_seed(seed)
 
     image = pipe(
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
+        width=width,
         height=height,
+        generator=generator
+    ).images[0]
 
     return image, seed
 
+@spaces.GPU(duration=200)
+def process_workflow(image, text_prompt, vlm_model_choice, use_enhancer, model_choice, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
+    if image is not None:
+        # Convert image to PIL if it's not already
+        if not isinstance(image, Image.Image):
+            image = Image.fromarray(image)
+
+        if vlm_model_choice == "Long Captioner":
+            prompt = create_captions_rich(image)
+        else: # Florence
+            prompt = florence_caption(image)
+    else:
+        prompt = text_prompt
+
+    if use_enhancer:
+        prompt = enhance_prompt(prompt, model_choice)
+
+    generated_image, used_seed = generate_image(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps)
+
+    return generated_image, prompt, used_seed
+
+custom_css = """
+.input-group, .output-group {
+    border: 1px solid #e0e0e0;
+    border-radius: 10px;
+    padding: 20px;
+    margin-bottom: 20px;
+    background-color: #f9f9f9;
+}
+.submit-btn {
+    background-color: #2980b9 !important;
+    color: white !important;
+}
+.submit-btn:hover {
+    background-color: #3498db !important;
 }
 """
 
+title = """<h1 align="center">AuraFlow with VLM Captioner and Prompt Enhancer</h1>
+<p><center>
+<a href="https://huggingface.co/fal/AuraFlow" target="_blank">[AuraFlow Model]</a>
+<a href="https://huggingface.co/spaces/multimodalart/AuraFlow" target="_blank">[Original Space]</a>
+<a href="https://huggingface.co/microsoft/Florence-2-base" target="_blank">[Florence-2 Model]</a>
+<a href="https://huggingface.co/gokaygokay/sd3-long-captioner-v2" target="_blank">[Long Captioner Model]</a>
+<a href="https://huggingface.co/gokaygokay/Lamini-Prompt-Enchance-Long" target="_blank">[Prompt Enhancer Long]</a>
+<a href="https://huggingface.co/gokaygokay/Lamini-fal-prompt-enchance" target="_blank">[Prompt Enhancer Medium]</a>
+<p align="center">Create long prompts from images or enhance your short prompts with prompt enhancer</p>
+</center></p>
+"""
 
+with gr.Blocks(css=custom_css, theme=gr.themes.Soft(primary_hue="blue", secondary_hue="gray")) as demo:
+    gr.HTML(title)
 
+    with gr.Row():
+        with gr.Column(scale=1):
+            with gr.Group(elem_classes="input-group"):
+                input_image = gr.Image(label="Input Image (VLM Captioner)")
+                vlm_model_choice = gr.Radio(["Florence-2", "Long Captioner"], label="VLM Model", value="Florence-2")
 
+            with gr.Accordion("Advanced Settings", open=False):
+                text_prompt = gr.Textbox(label="Text Prompt (optional, used if no image is uploaded)")
+                use_enhancer = gr.Checkbox(label="Use Prompt Enhancer", value=False)
+                model_choice = gr.Radio(["Medium", "Long"], label="Enhancer Model", value="Long")
+                negative_prompt = gr.Textbox(label="Negative Prompt")
+                seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
+                randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
+                width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=1024)
+                guidance_scale = gr.Slider(label="Guidance Scale", minimum=0.0, maximum=10.0, step=0.1, value=5.0)
+                num_inference_steps = gr.Slider(label="Inference Steps", minimum=1, maximum=50, step=1, value=28)
 
+            generate_btn = gr.Button("Generate Image", elem_classes="submit-btn")
 
+        with gr.Column(scale=1):
+            with gr.Group(elem_classes="output-group"):
+                output_image = gr.Image(label="Result", elem_id="gallery", show_label=False)
+                final_prompt = gr.Textbox(label="Final Prompt Used")
+                used_seed = gr.Number(label="Seed Used")
+
+    generate_btn.click(
+        fn=process_workflow,
+        inputs=[
+            input_image, text_prompt, vlm_model_choice, use_enhancer, model_choice,
+            negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps
+        ],
+        outputs=[output_image, final_prompt, used_seed]
     )
 
+demo.launch(debug=True)
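
After this commit the whole Space hangs off process_workflow, so a quick headless smoke test is possible. The sketch below is illustrative and not part of the commit: it assumes demo.launch() at the bottom of app.py is temporarily disabled so the import does not block, and that the spaces.GPU decorator degrades gracefully outside a ZeroGPU Space; the prompt and output filename are made up.

    # Hypothetical smoke test (not part of this commit). Importing app loads
    # all four models at module level, so this needs the same hardware the
    # Space itself needs.
    from app import process_workflow

    # Text-only path: with image=None the text prompt is used directly,
    # then rewritten by the "Long" enhancer before AuraFlow generates.
    image, final_prompt, used_seed = process_workflow(
        image=None,
        text_prompt="a lighthouse on a cliff at dusk",
        vlm_model_choice="Florence-2",  # ignored when no image is given
        use_enhancer=True,
        model_choice="Long",
        negative_prompt="",
        seed=0,
        randomize_seed=True,
        width=1024,
        height=1024,
        guidance_scale=5.0,
        num_inference_steps=28,
    )
    print(used_seed, final_prompt)
    image.save("smoke_test.png")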