Spaces:
Running
on
A10G
Running
on
A10G
Linoy Tsaban
commited on
Commit
·
8b5d4bf
1
Parent(s):
d19d91b
Update app.py
Browse files
app.py
CHANGED
@@ -28,7 +28,7 @@ blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image
|
|
28 |
def caption_image(input_image):
|
29 |
inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
|
30 |
pixel_values = inputs.pixel_values
|
31 |
-
|
32 |
generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
|
33 |
generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
34 |
return generated_caption
|
@@ -38,9 +38,9 @@ def caption_image(input_image):
|
|
38 |
## DDPM INVERSION AND SAMPLING ##
|
39 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
40 |
|
41 |
-
# inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
|
42 |
# based on the code in https://github.com/inbarhub/DDPM_inversion
|
43 |
-
|
44 |
# returns wt, zs, wts:
|
45 |
# wt - inverted latent
|
46 |
# wts - intermediate inverted latents
|
@@ -50,7 +50,7 @@ def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta
|
|
50 |
|
51 |
# vae encode image
|
52 |
with inference_mode():
|
53 |
-
|
54 |
|
55 |
# find Zs and wts - forward process
|
56 |
wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=num_diffusion_steps)
|
@@ -61,10 +61,10 @@ def sample(zs, wts, prompt_tar="", cfg_scale_tar=15, skip=36, eta = 1):
|
|
61 |
|
62 |
# reverse process (via Zs and wT)
|
63 |
w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
|
64 |
-
|
65 |
# vae decode image
|
66 |
with inference_mode():
|
67 |
-
|
68 |
if x0_dec.dim()<4:
|
69 |
x0_dec = x0_dec[None,:,:,:]
|
70 |
img = image_grid(x0_dec)
|
@@ -142,7 +142,7 @@ def edit(input_image,
|
|
142 |
src_cfg_scale):
|
143 |
|
144 |
if do_inversion or randomize_seed:
|
145 |
-
x0 = load_512(input_image, device=device)
|
146 |
# invert and retrieve noise maps and latent
|
147 |
zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
|
148 |
wts = gr.State(value=wts_tensor)
|
|
|
28 |
def caption_image(input_image):
|
29 |
inputs = blip_processor(images=input_image, return_tensors="pt").to(device, torch.float16)
|
30 |
pixel_values = inputs.pixel_values
|
31 |
+
|
32 |
generated_ids = blip_model.generate(pixel_values=pixel_values, max_length=50)
|
33 |
generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
34 |
return generated_caption
|
|
|
38 |
## DDPM INVERSION AND SAMPLING ##
|
39 |
def invert(x0, prompt_src="", num_diffusion_steps=100, cfg_scale_src = 3.5, eta = 1):
|
40 |
|
41 |
+
# inverts a real image according to Algorihm 1 in https://arxiv.org/pdf/2304.06140.pdf,
|
42 |
# based on the code in https://github.com/inbarhub/DDPM_inversion
|
43 |
+
|
44 |
# returns wt, zs, wts:
|
45 |
# wt - inverted latent
|
46 |
# wts - intermediate inverted latents
|
|
|
50 |
|
51 |
# vae encode image
|
52 |
with inference_mode():
|
53 |
+
w0 = (sd_pipe.vae.encode(x0).latent_dist.mode() * 0.18215)
|
54 |
|
55 |
# find Zs and wts - forward process
|
56 |
wt, zs, wts = inversion_forward_process(sd_pipe, w0, etas=eta, prompt=prompt_src, cfg_scale=cfg_scale_src, prog_bar=True, num_inference_steps=num_diffusion_steps)
|
|
|
61 |
|
62 |
# reverse process (via Zs and wT)
|
63 |
w0, _ = inversion_reverse_process(sd_pipe, xT=wts[skip], etas=eta, prompts=[prompt_tar], cfg_scales=[cfg_scale_tar], prog_bar=True, zs=zs[skip:])
|
64 |
+
|
65 |
# vae decode image
|
66 |
with inference_mode():
|
67 |
+
x0_dec = sd_pipe.vae.decode(1 / 0.18215 * w0).sample
|
68 |
if x0_dec.dim()<4:
|
69 |
x0_dec = x0_dec[None,:,:,:]
|
70 |
img = image_grid(x0_dec)
|
|
|
142 |
src_cfg_scale):
|
143 |
|
144 |
if do_inversion or randomize_seed:
|
145 |
+
x0 = load_512(input_image, device=device).to(torch.float16)
|
146 |
# invert and retrieve noise maps and latent
|
147 |
zs_tensor, wts_tensor = invert(x0 =x0 , prompt_src=src_prompt, num_diffusion_steps=steps, cfg_scale_src=src_cfg_scale)
|
148 |
wts = gr.State(value=wts_tensor)
|