new update
Files changed:
- src/attentionhacked_tryon.py (+3 -0)
- src/tryon_pipeline.py (+105 -120)
- src/unet_block_hacked_garmnet.py (+3 -3)
src/attentionhacked_tryon.py
CHANGED
@@ -331,6 +331,7 @@ class BasicTransformerBlock(nn.Module):
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
 
+        #type2
         modify_norm_hidden_states = torch.cat([norm_hidden_states,garment_features[curr_garment_feat_idx]], dim=1)
         curr_garment_feat_idx +=1
         attn_output = self.attn1(
@@ -345,6 +346,8 @@ class BasicTransformerBlock(nn.Module):
         elif self.use_ada_layer_norm_single:
             attn_output = gate_msa * attn_output
 
+
+        #type2
         hidden_states = attn_output[:,:hidden_states.shape[-2],:] + hidden_states
 
 
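For orientation, the #type2 path concatenates garment tokens onto the normalized hidden states along the sequence axis before self-attention, then keeps only the original token positions from the attention output. A minimal, self-contained sketch of that pattern (toy shapes, and a stock nn.MultiheadAttention standing in for the repo's attn1; all names here are illustrative):

import torch
import torch.nn as nn

# Hypothetical sizes, for illustration only.
batch, person_tokens, garment_tokens, dim = 2, 1024, 1024, 320

attn1 = nn.MultiheadAttention(dim, num_heads=8, batch_first=True)

norm_hidden_states = torch.randn(batch, person_tokens, dim)  # person-image tokens
garment_feature = torch.randn(batch, garment_tokens, dim)    # tokens from the garment UNet

# Concatenate garment tokens so self-attention attends across both streams.
modified = torch.cat([norm_hidden_states, garment_feature], dim=1)
attn_output, _ = attn1(modified, modified, modified)

# Keep only the person-token positions; the garment tokens served as context.
hidden_states = attn_output[:, :norm_hidden_states.shape[-2], :] + norm_hidden_states
print(hidden_states.shape)  # torch.Size([2, 1024, 320])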
src/tryon_pipeline.py
CHANGED
@@ -56,11 +56,8 @@ from diffusers.utils import (
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 
-from util.pipeline import torch_gc
 
 
-USE_PEFT_BACKEND = True
-
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -567,7 +564,7 @@ class StableDiffusionXLInpaintPipeline(
         the output of the pre-final layer will be used for computing the prompt embeddings.
         """
         device = device or self._execution_device
-
+
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
         if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
@@ -1297,8 +1294,6 @@ class StableDiffusionXLInpaintPipeline(
         pooled_prompt_embeds_c=None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = torch.device('cuda'),
         **kwargs,
     ):
         r"""
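Dropping the dtype/device parameters means the pipeline must resolve both internally, which the later hunks do via self._execution_device and prompt_embeds.dtype. A hedged sketch of why deriving them from module state beats a hard-coded torch.device('cuda') default (toy class, not the actual pipeline):

import torch
import torch.nn as nn

class TinyPipeline(nn.Module):
    """Toy stand-in: derive device/dtype from registered modules, not arguments."""
    def __init__(self):
        super().__init__()
        self.unet = nn.Linear(4, 4)

    @property
    def _execution_device(self) -> torch.device:
        # Follow whatever device the weights actually live on.
        return next(self.unet.parameters()).device

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        device = self._execution_device
        dtype = next(self.unet.parameters()).dtype
        return self.unet(x.to(device=device, dtype=dtype))

pipe = TinyPipeline()          # works on CPU-only machines,
out = pipe(torch.randn(2, 4))  # unlike a torch.device('cuda') default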
@@ -1528,7 +1523,7 @@ class StableDiffusionXLInpaintPipeline(
         else:
             batch_size = prompt_embeds.shape[0]
 
-
+        device = self._execution_device
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
@@ -1555,10 +1550,6 @@ class StableDiffusionXLInpaintPipeline(
             lora_scale=text_encoder_lora_scale,
             clip_skip=self.clip_skip,
         )
-        #move encoders to cpu for free memory
-        self.text_encoder.to('cpu')
-        self.text_encoder_2.to('cpu')
-        torch_gc()
 
         # 4. set timesteps
         def denoising_value_valid(dnv):
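The hand-rolled offload (moving both text encoders to CPU, then torch_gc()) is deleted, consistent with removing the util.pipeline import above. diffusers ships managed offloading that covers the same ground; a sketch of the stock alternative, assuming a standard SDXL inpaint checkpoint (the model id is illustrative):

import torch
from diffusers import StableDiffusionXLInpaintPipeline

# Illustrative checkpoint; loading downloads several GB of weights.
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
)

# Each sub-model is moved to the GPU only while it runs, then returned
# to CPU, replacing manual text_encoder.to('cpu') + gc bookkeeping.
pipe.enable_model_cpu_offload()

# After a generation, this re-arms the offload hooks; the commit adds a
# call to it in the pipeline's own cleanup path (see the last hunk).
pipe.maybe_free_model_hooks()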
@@ -1618,7 +1609,7 @@ class StableDiffusionXLInpaintPipeline(
             num_channels_latents,
             height,
             width,
-            dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             latents,
@@ -1642,12 +1633,12 @@ class StableDiffusionXLInpaintPipeline(
             batch_size * num_images_per_prompt,
             height,
             width,
-            dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             self.do_classifier_free_guidance,
         )
-        pose_img = pose_img.to(device=device, dtype=dtype)
+        pose_img = pose_img.to(device=device, dtype=prompt_embeds.dtype)
 
         pose_img = self.vae.encode(pose_img).latent_dist.sample()
         pose_img = pose_img * self.vae.config.scaling_factor
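Casting pose_img to prompt_embeds.dtype keeps inputs in the same precision as the loaded weights rather than trusting a caller-supplied dtype. A minimal sketch of the mismatch this avoids (a toy float32 module fed a float16 tensor; names are illustrative):

import torch
import torch.nn as nn

encoder = nn.Conv2d(3, 4, 3)                               # weights in float32
pose_img = torch.randn(1, 3, 8, 8, dtype=torch.float16)   # arrives in float16

try:
    encoder(pose_img)                                      # dtype mismatch
except RuntimeError as err:
    print("fails:", err)

# Fix: follow a tensor that already matches the loaded weights.
prompt_embeds = torch.randn(1, 4)                          # float32, like the weights here
out = encoder(pose_img.to(dtype=prompt_embeds.dtype))
print(out.shape)                                           # torch.Size([1, 4, 6, 6])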
@@ -1728,12 +1719,9 @@ class StableDiffusionXLInpaintPipeline(
                     ip_adapter_image, device, batch_size * num_images_per_prompt
                 )
 
-            #put unet on same device
-            self.unet.to(device)
-            #image_embeds = image_embeds.to(dtype)
             #project outside for loop
-
-
+            image_embeds = self.unet.encoder_hid_proj(image_embeds).to(prompt_embeds.dtype)
+
 
         # 11. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
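The image embeddings now pass through the UNet's encoder_hid_proj once, before the denoising loop; the projection does not depend on the timestep, so recomputing it per step was pure overhead. A schematic of the hoisting pattern (a toy linear layer stands in for encoder_hid_proj):

import torch
import torch.nn as nn

encoder_hid_proj = nn.Linear(16, 32)   # stand-in for unet.encoder_hid_proj
image_embeds = torch.randn(2, 16)
timesteps = range(30)

# Before: the same projection recomputed on every step.
for t in timesteps:
    step_embeds = encoder_hid_proj(image_embeds)   # loop-invariant work

# After: project once, reuse the result inside the loop.
image_embeds = encoder_hid_proj(image_embeds)
for t in timesteps:
    step_embeds = image_embeds                     # no per-step recompute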
@@ -1770,111 +1758,109 @@ class StableDiffusionXLInpaintPipeline(
 
 
         self._num_timesteps = len(timesteps)
-        with
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                # concat latents, mask, masked_image_latents in the channel dimension
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+
+                # bsz = mask.shape[0]
+                if num_channels_unet == 13:
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents,pose_img], dim=1)
+
+                # if num_channels_unet == 9:
+                #     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+                # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
+                down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
+                reference_features = list(reference_features)
+                if self.do_classifier_free_guidance:
+                    reference_features = [torch.cat([torch.zeros_like(d), d]) for d in reference_features]
+
+
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                    garment_features=reference_features,
+                )[0]
+                # noise_pred = self.unet(latent_model_input, t,
+                #     prompt_embeds,timestep_cond=timestep_cond,cross_attention_kwargs=self.cross_attention_kwargs,added_cond_kwargs=added_cond_kwargs,down_block_additional_attn=down ).sample
+
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if num_channels_unet == 4:
+                    init_latents_proper = image_latents
                     if self.do_classifier_free_guidance:
+                        init_mask, _ = mask.chunk(2)
+                    else:
+                        init_mask = mask
+
+                    if i < len(timesteps) - 1:
+                        noise_timestep = timesteps[i + 1]
+                        init_latents_proper = self.scheduler.add_noise(
+                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                        )
 
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
 
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
 
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+                    mask = callback_outputs.pop("mask", mask)
+                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
 
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
 
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
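The guidance block above is standard classifier-free guidance followed by the rescaling trick from Section 3.4 of https://arxiv.org/pdf/2305.08891.pdf. A self-contained sketch of both steps on random stand-in predictions, with rescale_noise_cfg written out in the form diffusers pipelines define it:

import torch

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Rescale the guided noise toward the std of the text-conditioned branch
    # to counteract CFG over-exposure (arXiv:2305.08891, Sec. 3.4).
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg

guidance_scale, guidance_rescale = 7.5, 0.7
noise_pred = torch.randn(2, 4, 8, 8)                 # batched [uncond, text]
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
print(noise_pred.shape)  # torch.Size([1, 4, 8, 8])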
@@ -1899,8 +1885,7 @@ class StableDiffusionXLInpaintPipeline(
             image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
 
         # Offload all models
-
-        self.maybe_free_model_hooks()
+        self.maybe_free_model_hooks()
 
         # if not return_dict:
         return (image,)
src/unet_block_hacked_garmnet.py
CHANGED
@@ -17,7 +17,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-
+
 from diffusers.utils import is_torch_version, logging
 from diffusers.utils.torch_utils import apply_freeu
 from diffusers.models.activations import get_activation
@@ -1252,7 +1252,7 @@ class DownBlock2D(nn.Module):
         def create_custom_forward(module):
             def custom_forward(*inputs):
                 return module(*inputs)
-
+
             return custom_forward
 
         if is_torch_version(">=", "1.11.0"):
@@ -1263,7 +1263,7 @@ class DownBlock2D(nn.Module):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet), hidden_states, temb
                 )
-            else:
+            else:
                 hidden_states = resnet(hidden_states, temb, scale=scale)
 
             output_states = output_states + (hidden_states,)
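These three hunks are whitespace-only, but they sit inside the usual gradient-checkpointing idiom: each resnet is wrapped in a closure so torch.utils.checkpoint can recompute its activations during backward instead of storing them. A minimal runnable sketch of that idiom (toy block; use_reentrant=False follows the torch >= 1.11 branch shown above):

import torch
import torch.nn as nn

def create_custom_forward(module):
    # The closure gives checkpoint a plain callable over positional inputs.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

resnet = nn.Sequential(nn.Linear(8, 8), nn.SiLU(), nn.Linear(8, 8))
hidden_states = torch.randn(2, 8, requires_grad=True)

# Activations inside `resnet` are recomputed during backward,
# trading compute for memory.
hidden_states = torch.utils.checkpoint.checkpoint(
    create_custom_forward(resnet), hidden_states, use_reentrant=False
)
hidden_states.sum().backward()
print(hidden_states.shape)  # torch.Size([2, 8])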