Spaces:
Runtime error
Runtime error
Prgckwb
committed on
Commit
·
6a91e71
1
Parent(s):
3c9c988
:tada: init
Browse files
- app.py +23 -7
- assets/ramen.jpg +0 -0
app.py
CHANGED
@@ -7,6 +7,7 @@ import gradio as gr
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import numpy as np
|
9 |
import torch
|
|
|
10 |
import torch.nn.functional as F
|
11 |
from PIL import Image
|
12 |
from pathlib import Path
|
@@ -188,7 +189,11 @@ unet.set_attn_processor(
|
|
188 |
|
189 |
|
190 |
@torch.inference_mode()
|
191 |
-
def inference(
|
|
|
|
|
|
|
|
|
192 |
progress(0, "Initializing...")
|
193 |
image = Image.open(image_path)
|
194 |
image = image.convert("RGB").resize((512, 512))
|
@@ -269,7 +274,11 @@ def inference(image_path: str, prompt: str, progress=gr.Progress(track_tqdm=Fals
|
|
269 |
).squeeze(0) # (77, 512, 512)
|
270 |
|
271 |
# <bos> と <eos> トークンの間に挿入されたトークンのみを取得
|
272 |
-
|
|
|
|
|
|
|
|
|
273 |
cross_attention_probs_list.append(mean_cross_attn_probs)
|
274 |
|
275 |
# list -> torch.Tensor
|
@@ -281,7 +290,10 @@ def inference(image_path: str, prompt: str, progress=gr.Progress(track_tqdm=Fals
|
|
281 |
image_list = []
|
282 |
# 各行ごとに画像を作成し保存
|
283 |
for i in tqdm(range(cross_attention_probs.shape[0]), desc="Saving images..."):
|
284 |
-
|
|
|
|
|
|
|
285 |
|
286 |
for j in range(cross_attention_probs.shape[1]):
|
287 |
# 各クラスのアテンションマップを Min-Max 正規化 (0~1)
|
@@ -297,12 +309,15 @@ def inference(image_path: str, prompt: str, progress=gr.Progress(track_tqdm=Fals
|
|
297 |
# 各行ごとの画像を保存
|
298 |
out_dir = Path("output")
|
299 |
out_dir.mkdir(exist_ok=True)
|
300 |
-
|
|
|
|
|
301 |
plt.savefig(filepath, bbox_inches='tight', pad_inches=0)
|
302 |
plt.close(fig)
|
303 |
|
304 |
# 保存した画像をPILで読み込んでリストに追加
|
305 |
image_list.append(Image.open(filepath))
|
|
|
306 |
return image_list
|
307 |
|
308 |
|
@@ -333,13 +348,14 @@ if __name__ == '__main__':
|
|
333 |
fn=inference,
|
334 |
inputs=[
|
335 |
gr.Image(type="filepath", label="Input", width=512, height=512),
|
336 |
-
gr.Textbox(label="Prompt", placeholder="e.g.) A photo of dog...")
|
|
|
337 |
],
|
338 |
outputs=ca_output,
|
339 |
cache_examples=True,
|
340 |
examples=[
|
341 |
-
["assets/aeroplane.png", "plane background"],
|
342 |
-
["assets/dogcat.png", "a photo of dog
|
343 |
]
|
344 |
)
|
345 |
|
|
|
7 |
import matplotlib.pyplot as plt
|
8 |
import numpy as np
|
9 |
import torch
|
10 |
+
import uuid
|
11 |
import torch.nn.functional as F
|
12 |
from PIL import Image
|
13 |
from pathlib import Path
|
|
|
189 |
|
190 |
|
191 |
@torch.inference_mode()
|
192 |
+
def inference(
|
193 |
+
image_path: str,
|
194 |
+
prompt: str,
|
195 |
+
has_include_special_tokens: bool = False,
|
196 |
+
progress=gr.Progress(track_tqdm=False)):
|
197 |
progress(0, "Initializing...")
|
198 |
image = Image.open(image_path)
|
199 |
image = image.convert("RGB").resize((512, 512))
|
|
|
274 |
).squeeze(0) # (77, 512, 512)
|
275 |
|
276 |
# <bos> と <eos> トークンの間に挿入されたトークンのみを取得
|
277 |
+
if has_include_special_tokens:
|
278 |
+
mean_cross_attn_probs = mean_cross_attn_probs[:n_cond_tokens, ...] # (n_tokens, 512, 512)
|
279 |
+
else:
|
280 |
+
mean_cross_attn_probs = mean_cross_attn_probs[1:n_cond_tokens - 1, ...]
|
281 |
+
|
282 |
cross_attention_probs_list.append(mean_cross_attn_probs)
|
283 |
|
284 |
# list -> torch.Tensor
|
|
|
290 |
image_list = []
|
291 |
# 各行ごとに画像を作成し保存
|
292 |
for i in tqdm(range(cross_attention_probs.shape[0]), desc="Saving images..."):
|
293 |
+
if has_include_special_tokens:
|
294 |
+
fig, ax = plt.subplots(1, n_cond_tokens, figsize=(16, 4))
|
295 |
+
else:
|
296 |
+
fig, ax = plt.subplots(1, n_cond_tokens - 2, figsize=(16, 4))
|
297 |
|
298 |
for j in range(cross_attention_probs.shape[1]):
|
299 |
# 各クラスのアテンションマップを Min-Max 正規化 (0~1)
|
|
|
309 |
# 各行ごとの画像を保存
|
310 |
out_dir = Path("output")
|
311 |
out_dir.mkdir(exist_ok=True)
|
312 |
+
# 一意なランダムファイル名を生成
|
313 |
+
unique_filename = str(uuid.uuid4())
|
314 |
+
filepath = out_dir / f"{unique_filename}.png"
|
315 |
plt.savefig(filepath, bbox_inches='tight', pad_inches=0)
|
316 |
plt.close(fig)
|
317 |
|
318 |
# 保存した画像をPILで読み込んでリストに追加
|
319 |
image_list.append(Image.open(filepath))
|
320 |
+
attn_processor.reset_attention_stores()
|
321 |
return image_list
|
322 |
|
323 |
|
|
|
348 |
fn=inference,
|
349 |
inputs=[
|
350 |
gr.Image(type="filepath", label="Input", width=512, height=512),
|
351 |
+
gr.Textbox(label="Prompt", placeholder="e.g.) A photo of dog..."),
|
352 |
+
gr.Checkbox(label="Include Special Tokens", value=False),
|
353 |
],
|
354 |
outputs=ca_output,
|
355 |
cache_examples=True,
|
356 |
examples=[
|
357 |
+
["assets/aeroplane.png", "plane background", False],
|
358 |
+
["assets/dogcat.png", "a photo of dog", False],
|
359 |
]
|
360 |
)
|
361 |
|
assets/ramen.jpg
ADDED
![]() |