Roopansh committed on
Commit a6ca64b • 1 Parent(s): fe331b9

new update

src/attentionhacked_tryon.py CHANGED
@@ -331,6 +331,7 @@ class BasicTransformerBlock(nn.Module):
         gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
 
+        #type2
         modify_norm_hidden_states = torch.cat([norm_hidden_states,garment_features[curr_garment_feat_idx]], dim=1)
         curr_garment_feat_idx +=1
         attn_output = self.attn1(
@@ -345,6 +346,8 @@ class BasicTransformerBlock(nn.Module):
         elif self.use_ada_layer_norm_single:
             attn_output = gate_msa * attn_output
 
+
+        #type2
         hidden_states = attn_output[:,:hidden_states.shape[-2],:] + hidden_states
 
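For orientation, a minimal sketch of what the block being tagged here does: garment (reference) tokens are concatenated onto the normalized hidden states along the sequence axis, self-attention runs over the widened sequence, and the output is cropped back to the original length. The names follow the diff; the standalone wrapper function itself is hypothetical, written only to illustrate the tensor shapes, and `attn1` is assumed to be a callable self-attention module.

import torch

def attn1_with_garment_tokens(attn1, norm_hidden_states, garment_feat):
    # norm_hidden_states: (batch, seq_len, dim); garment_feat: (batch, g_len, dim)
    seq_len = norm_hidden_states.shape[1]
    # Append garment tokens along the sequence axis so the person tokens can
    # attend to them as extra keys/values during self-attention.
    modified = torch.cat([norm_hidden_states, garment_feat], dim=1)
    attn_output = attn1(modified)
    # Keep only the outputs at the original positions, mirroring
    # attn_output[:, :hidden_states.shape[-2], :] in the diff above.
    return attn_output[:, :seq_len, :]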
src/tryon_pipeline.py CHANGED
@@ -56,11 +56,8 @@ from diffusers.utils import (
 from diffusers.utils.torch_utils import randn_tensor
 from diffusers.pipelines.pipeline_utils import DiffusionPipeline
 
-from util.pipeline import torch_gc
 
 
-USE_PEFT_BACKEND = True
-
 if is_torch_xla_available():
     import torch_xla.core.xla_model as xm
 
@@ -567,7 +564,7 @@ class StableDiffusionXLInpaintPipeline(
             the output of the pre-final layer will be used for computing the prompt embeddings.
         """
         device = device or self._execution_device
-
+
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
         if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
@@ -1297,8 +1294,6 @@ class StableDiffusionXLInpaintPipeline(
         pooled_prompt_embeds_c=None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
-        dtype: torch.dtype = torch.float32,
-        device: torch.device = torch.device('cuda'),
         **kwargs,
     ):
         r"""
@@ -1528,7 +1523,7 @@ class StableDiffusionXLInpaintPipeline(
         else:
             batch_size = prompt_embeds.shape[0]
 
-        #device = self._execution_device
+        device = self._execution_device
 
         # 3. Encode input prompt
         text_encoder_lora_scale = (
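With the hard-coded `device`/`dtype` parameters removed from the call signature, the restored line falls back to diffusers' standard device resolution. A sketch of the idea — `_execution_device` is the real diffusers property (it accounts for offload hooks); the toy class around it is illustrative only and simply reads the first parameter's device:

import torch

class ToyPipeline:
    def __init__(self, unet: torch.nn.Module):
        self.unet = unet

    @property
    def _execution_device(self) -> torch.device:
        # Simplified stand-in for diffusers' offload-aware property.
        return next(self.unet.parameters()).device

    def __call__(self):
        device = self._execution_device  # mirrors the restored line above
        return device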
@@ -1555,10 +1550,6 @@ class StableDiffusionXLInpaintPipeline(
             lora_scale=text_encoder_lora_scale,
             clip_skip=self.clip_skip,
         )
-        #move encoders to cpu for free memory
-        self.text_encoder.to('cpu')
-        self.text_encoder_2.to('cpu')
-        torch_gc()
 
         # 4. set timesteps
        def denoising_value_valid(dnv):
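The lines removed here offloaded both text encoders to the CPU after prompt encoding and then called `torch_gc()` from the now-unimported `util.pipeline`. That helper's body is not shown anywhere in this diff; a plausible minimal equivalent (an assumption, not the actual `util.pipeline` code) would be:

import gc
import torch

def torch_gc() -> None:
    # Hypothetical stand-in for the removed util.pipeline.torch_gc helper:
    # run Python GC, then release cached CUDA blocks back to the driver.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()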
@@ -1618,7 +1609,7 @@ class StableDiffusionXLInpaintPipeline(
             num_channels_latents,
             height,
             width,
-            dtype, #prompt_embeds.dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             latents,
@@ -1642,12 +1633,12 @@ class StableDiffusionXLInpaintPipeline(
             batch_size * num_images_per_prompt,
             height,
             width,
-            dtype,
+            prompt_embeds.dtype,
             device,
             generator,
             self.do_classifier_free_guidance,
         )
-        pose_img = pose_img.to(device=device, dtype=dtype)
+        pose_img = pose_img.to(device=device, dtype=prompt_embeds.dtype)
 
         pose_img = self.vae.encode(pose_img).latent_dist.sample()
         pose_img = pose_img * self.vae.config.scaling_factor
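Both call sites now derive their dtype from `prompt_embeds` rather than a caller-supplied `dtype` argument, which keeps every tensor consistent with whatever precision the text encoders actually produced. A self-contained sketch of the pattern, with stand-in tensors (the VAE encode step is omitted):

import torch

# `prompt_embeds` stands in for the real prompt embeddings; shapes are illustrative.
prompt_embeds = torch.randn(2, 77, 2048, dtype=torch.float16)
pose_img = torch.randn(2, 3, 128, 96)  # e.g. a preprocessed pose map, float32

# Propagate device and dtype from the upstream tensor instead of threading
# a separate `dtype` parameter through the call, as in the hunk above.
pose_img = pose_img.to(device=prompt_embeds.device, dtype=prompt_embeds.dtype)
assert pose_img.dtype == torch.float16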
@@ -1728,12 +1719,9 @@ class StableDiffusionXLInpaintPipeline(
                 ip_adapter_image, device, batch_size * num_images_per_prompt
             )
 
-            #put unet on same device
-            self.unet.to(device)
-            #image_embeds = image_embeds.to(dtype)
             #project outside for loop
-            with torch.cuda.amp.autocast(dtype=dtype, enabled=True):
-                image_embeds = self.unet.encoder_hid_proj(image_embeds).to(dtype)
+            image_embeds = self.unet.encoder_hid_proj(image_embeds).to(prompt_embeds.dtype)
+
 
         # 11. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
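The "#project outside for loop" comment marks the intent: the IP-Adapter image embeddings are projected once, before the denoising loop, instead of once per step, and the manual `unet.to(device)` plus autocast wrapper are dropped. The hoisting pattern, sketched with a stand-in linear projection (`encoder_hid_proj` here is a toy substitute for `unet.encoder_hid_proj`, and the loop body is elided):

import torch

encoder_hid_proj = torch.nn.Linear(1280, 2048)
image_embeds = torch.randn(2, 16, 1280)
prompt_embeds = torch.randn(2, 77, 2048)

# Loop-invariant work happens once, matching the added line above.
image_embeds = encoder_hid_proj(image_embeds).to(prompt_embeds.dtype)

for t in range(30):  # denoising steps would go here
    added_cond_kwargs = {"image_embeds": image_embeds}  # reused, not recomputed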
@@ -1770,111 +1758,109 @@ class StableDiffusionXLInpaintPipeline(
 
 
         self._num_timesteps = len(timesteps)
-        with torch.cuda.amp.autocast(dtype=dtype, enabled=True):
-            with self.progress_bar(total=num_inference_steps) as progress_bar:
-                for i, t in enumerate(timesteps):
-                    t.to(dtype)
-                    if self.interrupt:
-                        continue
-                    # expand the latents if we are doing classifier free guidance
-                    latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-
-                    # concat latents, mask, masked_image_latents in the channel dimension
-                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-
-                    # bsz = mask.shape[0]
-                    if num_channels_unet == 13:
-                        latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents,pose_img], dim=1)
-
-                    # if num_channels_unet == 9:
-                    #     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
-
-                    # predict the noise residual
-                    added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
-                    if ip_adapter_image is not None:
-                        added_cond_kwargs["image_embeds"] = image_embeds
-                    # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
-                    down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
-                    # print(type(reference_features))
-                    # print(reference_features)
-                    reference_features = list(reference_features)
-                    # print(len(reference_features))
-                    # for elem in reference_features:
-                    #     print(elem.shape)
-                    # exit(1)
-                    if self.do_classifier_free_guidance:
-                        reference_features = [torch.cat([torch.zeros_like(d), d]) for d in reference_features]
-
-
-                    noise_pred = self.unet(
-                        latent_model_input,
-                        t,
-                        encoder_hidden_states=prompt_embeds,
-                        timestep_cond=timestep_cond,
-                        cross_attention_kwargs=self.cross_attention_kwargs,
-                        added_cond_kwargs=added_cond_kwargs,
-                        return_dict=False,
-                        garment_features=reference_features,
-                    )[0]
-                    # noise_pred = self.unet(latent_model_input, t,
-                    #                        prompt_embeds,timestep_cond=timestep_cond,cross_attention_kwargs=self.cross_attention_kwargs,added_cond_kwargs=added_cond_kwargs,down_block_additional_attn=down ).sample
-
-
-                    # perform guidance
-                    if self.do_classifier_free_guidance:
-                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-                    if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
-                        # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
-                        noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
-
-                    # compute the previous noisy sample x_t -> x_t-1
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-                    if num_channels_unet == 4:
-                        init_latents_proper = image_latents
-                        if self.do_classifier_free_guidance:
-                            init_mask, _ = mask.chunk(2)
-                        else:
-                            init_mask = mask
-
-                        if i < len(timesteps) - 1:
-                            noise_timestep = timesteps[i + 1]
-                            init_latents_proper = self.scheduler.add_noise(
-                                init_latents_proper, noise, torch.tensor([noise_timestep])
-                            )
-
-                        latents = (1 - init_mask) * init_latents_proper + init_mask * latents
-
-                    if callback_on_step_end is not None:
-                        callback_kwargs = {}
-                        for k in callback_on_step_end_tensor_inputs:
-                            callback_kwargs[k] = locals()[k]
-                        callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
-
-                        latents = callback_outputs.pop("latents", latents)
-                        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                        add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                        negative_pooled_prompt_embeds = callback_outputs.pop(
-                            "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                        )
-                        add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                        add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
-                        mask = callback_outputs.pop("mask", mask)
-                        masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
-
-                    # call the callback, if provided
-                    if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                        progress_bar.update()
-                        if callback is not None and i % callback_steps == 0:
-                            step_idx = i // getattr(self.scheduler, "order", 1)
-                            callback(step_idx, t, latents)
-
-                    if XLA_AVAILABLE:
-                        xm.mark_step()
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                if self.interrupt:
+                    continue
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
+
+                # concat latents, mask, masked_image_latents in the channel dimension
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+
+                # bsz = mask.shape[0]
+                if num_channels_unet == 13:
+                    latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents,pose_img], dim=1)
+
+                # if num_channels_unet == 9:
+                #     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
+
+                # predict the noise residual
+                added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+                if ip_adapter_image is not None:
+                    added_cond_kwargs["image_embeds"] = image_embeds
+                # down,reference_features = self.UNet_Encoder(cloth,t, text_embeds_cloth,added_cond_kwargs= {"text_embeds": pooled_prompt_embeds_c, "time_ids": add_time_ids},return_dict=False)
+                down,reference_features = self.unet_encoder(cloth,t, text_embeds_cloth,return_dict=False)
+                # print(type(reference_features))
+                # print(reference_features)
+                reference_features = list(reference_features)
+                # print(len(reference_features))
+                # for elem in reference_features:
+                #     print(elem.shape)
+                # exit(1)
+                if self.do_classifier_free_guidance:
+                    reference_features = [torch.cat([torch.zeros_like(d), d]) for d in reference_features]
+
+
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
+                    cross_attention_kwargs=self.cross_attention_kwargs,
+                    added_cond_kwargs=added_cond_kwargs,
+                    return_dict=False,
+                    garment_features=reference_features,
+                )[0]
+                # noise_pred = self.unet(latent_model_input, t,
+                #                        prompt_embeds,timestep_cond=timestep_cond,cross_attention_kwargs=self.cross_attention_kwargs,added_cond_kwargs=added_cond_kwargs,down_block_additional_attn=down ).sample
+
+
+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
+                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+                    noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
+
+                if num_channels_unet == 4:
+                    init_latents_proper = image_latents
+                    if self.do_classifier_free_guidance:
+                        init_mask, _ = mask.chunk(2)
+                    else:
+                        init_mask = mask
+
+                    if i < len(timesteps) - 1:
+                        noise_timestep = timesteps[i + 1]
+                        init_latents_proper = self.scheduler.add_noise(
+                            init_latents_proper, noise, torch.tensor([noise_timestep])
+                        )
+
+                    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+                    add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
+                    negative_pooled_prompt_embeds = callback_outputs.pop(
+                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
+                    )
+                    add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
+                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
+                    mask = callback_outputs.pop("mask", mask)
+                    masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                    if callback is not None and i % callback_steps == 0:
+                        step_idx = i // getattr(self.scheduler, "order", 1)
+                        callback(step_idx, t, latents)
+
+                if XLA_AVAILABLE:
+                    xm.mark_step()
 
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
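This hunk removes the `torch.cuda.amp.autocast` wrapper and the per-step `t.to(dtype)`; the loop body is otherwise re-added unchanged, dedented one level. One detail worth illustrating: under classifier-free guidance the batch is doubled, so the garment reference features are doubled too, with zeros standing in for the unconditional half (`torch.cat([torch.zeros_like(d), d])`). A small numeric sketch of that step plus the guidance combine, using stand-in tensors:

import torch

guidance_scale = 2.0
# One reference feature map; zeros are prepended as the unconditional half,
# mirroring the list comprehension in the loop above.
reference_features = [torch.randn(1, 4, 8)]
reference_features = [torch.cat([torch.zeros_like(d), d]) for d in reference_features]

# Stand-in for the UNet output on the doubled batch (uncond first, text second).
noise_pred = torch.randn(2, 4, 8)
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)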
@@ -1899,8 +1885,7 @@ class StableDiffusionXLInpaintPipeline(
             image = [self.image_processor.apply_overlay(mask_image, original_image, i, crops_coords) for i in image]
 
         # Offload all models
-        if device.type=='cpu':
-            self.maybe_free_model_hooks()
+        self.maybe_free_model_hooks()
 
         # if not return_dict:
         return (image,)
src/unet_block_hacked_garmnet.py CHANGED
@@ -17,7 +17,7 @@ import numpy as np
 import torch
 import torch.nn.functional as F
 from torch import nn
-import bitsandbytes as bnb
+
 from diffusers.utils import is_torch_version, logging
 from diffusers.utils.torch_utils import apply_freeu
 from diffusers.models.activations import get_activation
@@ -1252,7 +1252,7 @@ class DownBlock2D(nn.Module):
         def create_custom_forward(module):
             def custom_forward(*inputs):
                 return module(*inputs)
-
+
             return custom_forward
 
         if is_torch_version(">=", "1.11.0"):
@@ -1263,7 +1263,7 @@ class DownBlock2D(nn.Module):
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(resnet), hidden_states, temb
                 )
-            else:
+            else:
                 hidden_states = resnet(hidden_states, temb, scale=scale)
 
             output_states = output_states + (hidden_states,)
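The `create_custom_forward` closure touched by these hunks is the standard trick for using `torch.utils.checkpoint` with modules that take extra positional arguments, and the `else:` branch is the non-checkpointed path. A self-contained sketch of the pattern, assuming a toy `ResBlock` in place of the real resnet (the `use_reentrant=False` flag is the recommended modern setting, not something this file necessarily passes):

import torch
import torch.nn as nn

class ResBlock(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, hidden_states, temb):
        # Toy residual block taking an extra time-embedding argument.
        return self.proj(hidden_states) + temb

def create_custom_forward(module):
    # Wrap the module so checkpoint() can call it with *inputs.
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

resnet = ResBlock(16)
hidden_states = torch.randn(2, 16, requires_grad=True)
temb = torch.randn(2, 16)

training_with_checkpointing = True
if training_with_checkpointing:
    # Recompute activations in the backward pass to save memory.
    hidden_states = torch.utils.checkpoint.checkpoint(
        create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
    )
else:
    hidden_states = resnet(hidden_states, temb)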