Leimingkun committed
Commit
6fe0b16
1 Parent(s): da92c10

stylestudio

Files changed (3):
  1. app.py +4 -3
  2. ip_adapter/attention_processor.py +18 -627
  3. ip_adapter/ip_adapter.py +11 -487
app.py CHANGED
@@ -85,6 +85,7 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
     if randomize_seed:
         seed = random.randint(0, MAX_SEED)
     return seed
+
 @spaces.GPU
 def create_image(
     style_image_pil,
@@ -95,7 +96,7 @@ def create_image(
     crossModalAdaIN,
     use_SAttn,
     seed,
-    negative_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
+    neg_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
 ):
 
     style_image = style_image_pil
@@ -109,7 +110,7 @@ def create_image(
     with torch.no_grad():
         images = csgo.generate(pil_style_image=style_image,
                                prompt=prompt,
-                               negative_prompt=negative_prompt,
+                               negative_prompt=neg_prompt,
                                height=1024,
                                width=1024,
                                guidance_scale=guidance_scale,
@@ -231,7 +232,7 @@ with block:
         inputs=[style_image_pil, target, prompt, guidance_scale, seed, end_fusion],
         fn=run_for_examples,
         outputs=[generated_image],
-        cache_examples=True,
+        cache_examples=False,
     )
 
     gr.Markdown(article)
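For context on the cache_examples change: with cache_examples=False, Gradio runs the examples function lazily when a row is clicked instead of pre-computing and caching every example output while the Space builds, which is presumably the point here since the generation path runs under @spaces.GPU. A self-contained toy sketch of that wiring (echo and demo are placeholders, not names from this Space):

import gradio as gr

def echo(prompt: str) -> str:
    # stand-in for run_for_examples / create_image in the real app
    return f"would generate: {prompt}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="prompt")
    generated = gr.Textbox(label="result")
    gr.Examples(
        examples=[["A red apple"], ["A house in Van Gogh style"]],
        inputs=[prompt],
        fn=echo,
        outputs=[generated],
        cache_examples=False,  # as in the hunk above: no eager pre-run at startup
    )

demo.launch()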
ip_adapter/attention_processor.py CHANGED
@@ -757,441 +757,6 @@ class CNAttnProcessor2_0:
757
 
758
  return hidden_states
759
 
760
- class IP_FuAd_AttnProcessor2_0(torch.nn.Module):
761
- r"""
762
- Attention processor for IP-Adapater for PyTorch 2.0.
763
- Args:
764
- hidden_size (`int`):
765
- The hidden size of the attention layer.
766
- cross_attention_dim (`int`):
767
- The number of channels in the `encoder_hidden_states`.
768
- scale (`float`, defaults to 1.0):
769
- the weight scale of image prompt.
770
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
771
- The context length of the image features.
772
- """
773
-
774
- def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0,style_scale=1.0, num_content_tokens=4,num_style_tokens=4,
775
- skip=False,content=False, style=False, fuAttn=False, fuIPAttn=False, adainIP=False,
776
- fuScale=0, end_fusion=0, attn_name=None):
777
- super().__init__()
778
-
779
- if not hasattr(F, "scaled_dot_product_attention"):
780
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
781
-
782
- self.hidden_size = hidden_size
783
- self.cross_attention_dim = cross_attention_dim
784
- self.content_scale = content_scale
785
- self.style_scale = style_scale
786
- self.num_style_tokens = num_style_tokens
787
- self.skip = skip
788
-
789
- self.content = content
790
- self.style = style
791
-
792
- self.fuAttn = fuAttn
793
- self.fuIPAttn = fuIPAttn
794
- self.adainIP = adainIP
795
- self.fuScale = fuScale
796
- self.denoise_step = 0
797
- self.end_fusion = end_fusion
798
- self.name = attn_name
799
-
800
- if self.content or self.style:
801
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
802
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
803
- self.to_k_ip_content =None
804
- self.to_v_ip_content =None
805
-
806
- # def set_content_ipa(self,content_scale=1.0):
807
-
808
- # self.to_k_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
809
- # self.to_v_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
810
- # self.content_scale=content_scale
811
- # self.content =True
812
-
813
- def reset_denoise_step(self):
814
- if self.denoise_step == 50:
815
- self.denoise_step = 0
816
- # if "up_blocks.0.attentions.1.transformer_blocks.0.attn2" in self.name:
817
- # print("attn2 reset successful")
818
-
819
- def __call__(
820
- self,
821
- attn,
822
- hidden_states,
823
- encoder_hidden_states=None,
824
- attention_mask=None,
825
- temb=None,
826
- ):
827
- self.denoise_step += 1
828
- residual = hidden_states
829
-
830
- if attn.spatial_norm is not None:
831
- hidden_states = attn.spatial_norm(hidden_states, temb)
832
-
833
- input_ndim = hidden_states.ndim
834
-
835
- if input_ndim == 4:
836
- batch_size, channel, height, width = hidden_states.shape
837
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
838
-
839
- batch_size, sequence_length, _ = (
840
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
841
- )
842
-
843
- if attention_mask is not None:
844
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
845
- # scaled_dot_product_attention expects attention_mask shape to be
846
- # (batch, heads, source_length, target_length)
847
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
848
-
849
- if attn.group_norm is not None:
850
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
851
-
852
- query = attn.to_q(hidden_states)
853
-
854
- if encoder_hidden_states is None:
855
- encoder_hidden_states = hidden_states
856
- else:
857
- # get encoder_hidden_states, ip_hidden_states
858
- end_pos = encoder_hidden_states.shape[1] -self.num_style_tokens
859
- encoder_hidden_states, ip_style_hidden_states = (
860
- encoder_hidden_states[:, :end_pos, :],
861
- encoder_hidden_states[:, end_pos:, :],
862
- )
863
- if attn.norm_cross:
864
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
865
-
866
- key = attn.to_k(encoder_hidden_states)
867
- value = attn.to_v(encoder_hidden_states)
868
-
869
- inner_dim = key.shape[-1]
870
- head_dim = inner_dim // attn.heads
871
-
872
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
873
-
874
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
875
-
876
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
877
-
878
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
879
- # TODO: add support for attn.scale when we move to Torch 2.1
880
- # # modified the attnMap of the Stylization Image
881
-
882
- if self.fuAttn and self.denoise_step <= self.end_fusion:
883
- assert query.shape[0] == 4
884
- scale_factor = 1 / math.sqrt(torch.tensor(head_dim, dtype=query.dtype))
885
- text_attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
886
- text_attn_probs[1] = self.fuScale*text_attn_probs[1] + (1-self.fuScale)*text_attn_probs[0]
887
- text_attn_probs[3] = self.fuScale*text_attn_probs[3] + (1-self.fuScale)*text_attn_probs[2]
888
- hidden_states = torch.matmul(text_attn_probs, value)
889
- else:
890
- hidden_states = F.scaled_dot_product_attention(
891
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
892
- )
893
-
894
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
895
- hidden_states = hidden_states.to(query.dtype)
896
-
897
- raw_hidden_states = hidden_states
898
-
899
- if not self.skip and self.style is True:
900
-
901
- # for ip-style-adapter
902
- ip_style_key = self.to_k_ip(ip_style_hidden_states)
903
- ip_style_value = self.to_v_ip(ip_style_hidden_states)
904
-
905
- ip_style_key = ip_style_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
906
- ip_style_value = ip_style_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
907
-
908
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
909
- # TODO: add support for attn.scale when we move to Torch 2.1
910
- if self.fuIPAttn and self.denoise_step <= self.end_fusion:
911
- assert query.shape[0] == 4
912
- if "down" in self.name:
913
- print("wrong! coding")
914
- exit()
915
- scale_factor = 1 / math.sqrt(torch.tensor(head_dim, dtype=query.dtype))
916
- ip_attn_probs = torch.matmul(query, ip_style_key.transpose(-2, -1)) * scale_factor
917
- ip_attn_probs = F.softmax(ip_attn_probs, dim=-1)
918
- ip_attn_probs[1] = self.fuScale*ip_attn_probs[1] + (1-self.fuScale)*ip_attn_probs[0]
919
- ip_attn_probs[3] = self.fuScale*ip_attn_probs[3] + (1-self.fuScale)*ip_attn_probs[2]
920
- ip_style_hidden_states = torch.matmul(ip_attn_probs, ip_style_value)
921
- else:
922
- ip_style_hidden_states = F.scaled_dot_product_attention(
923
- query, ip_style_key, ip_style_value, attn_mask=None, dropout_p=0.0, is_causal=False
924
- )
925
-
926
- ip_style_hidden_states = ip_style_hidden_states.transpose(1, 2).reshape(batch_size, -1,
927
- attn.heads * head_dim)
928
- ip_style_hidden_states = ip_style_hidden_states.to(query.dtype)
929
-
930
- if not self.adainIP:
931
- hidden_states = hidden_states + self.style_scale * ip_style_hidden_states
932
- else:
933
- # print("adain")
934
- def adain(content, style):
935
- content_mean = content.mean(dim=1, keepdim=True)
936
- content_std = content.std(dim=1, keepdim=True)
937
- style_mean = style.mean(dim=1, keepdim=True)
938
- style_std = style.std(dim=1, keepdim=True)
939
- normalized_content = (content - content_mean) / content_std
940
- stylized_content = normalized_content * style_std + style_mean
941
- return stylized_content
942
- hidden_states = adain(content=hidden_states, style=ip_style_hidden_states)
943
-
944
- if hidden_states.shape[0] == 4:
945
- hidden_states[0] = raw_hidden_states[0]
946
- hidden_states[2] = raw_hidden_states[2]
947
- # hidden_states = raw_hidden_states
948
-
949
- # linear proj
950
- hidden_states = attn.to_out[0](hidden_states)
951
- # dropout
952
- hidden_states = attn.to_out[1](hidden_states)
953
-
954
- if input_ndim == 4:
955
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
956
-
957
- if attn.residual_connection:
958
- hidden_states = hidden_states + residual
959
-
960
- hidden_states = hidden_states / attn.rescale_output_factor
961
-
962
- self.reset_denoise_step()
963
- return hidden_states
964
-
965
- class IP_FuAd_AttnProcessor2_0_exp(torch.nn.Module):
966
- r"""
967
- Attention processor for IP-Adapater for PyTorch 2.0.
968
- Args:
969
- hidden_size (`int`):
970
- The hidden size of the attention layer.
971
- cross_attention_dim (`int`):
972
- The number of channels in the `encoder_hidden_states`.
973
- scale (`float`, defaults to 1.0):
974
- the weight scale of image prompt.
975
- num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16):
976
- The context length of the image features.
977
- """
978
-
979
- def __init__(self, hidden_size, cross_attention_dim=None, content_scale=1.0,style_scale=1.0, num_content_tokens=4,num_style_tokens=4,
980
- skip=False,content=False, style=False, fuAttn=False, fuIPAttn=False, adainIP=False,
981
- fuScale=0, end_fusion=0, attn_name=None, save_attn_map=False):
982
- super().__init__()
983
-
984
- if not hasattr(F, "scaled_dot_product_attention"):
985
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
986
-
987
- self.hidden_size = hidden_size
988
- self.cross_attention_dim = cross_attention_dim
989
- self.content_scale = content_scale
990
- self.style_scale = style_scale
991
- self.num_style_tokens = num_style_tokens
992
- self.skip = skip
993
-
994
- self.content = content
995
- self.style = style
996
-
997
- self.fuAttn = fuAttn
998
- self.fuIPAttn = fuIPAttn
999
- self.adainIP = adainIP
1000
- self.fuScale = fuScale
1001
- self.denoise_step = 0
1002
- self.end_fusion = end_fusion
1003
- self.name = attn_name
1004
-
1005
- self.save_attn_map = save_attn_map
1006
-
1007
- if self.content or self.style:
1008
- self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1009
- self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1010
- self.to_k_ip_content =None
1011
- self.to_v_ip_content =None
1012
-
1013
- # def set_content_ipa(self,content_scale=1.0):
1014
-
1015
- # self.to_k_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
1016
- # self.to_v_ip_content = nn.Linear(self.cross_attention_dim or self.hidden_size, self.hidden_size, bias=False)
1017
- # self.content_scale=content_scale
1018
- # self.content =True
1019
- def reset_denoise_step(self):
1020
- if self.denoise_step == 50:
1021
- self.denoise_step = 0
1022
- # if "up_blocks.0.attentions.1.transformer_blocks.0.attn2" in self.name:
1023
- # print("attn2 reset successful")
1024
-
1025
- def __call__(
1026
- self,
1027
- attn,
1028
- hidden_states,
1029
- encoder_hidden_states=None,
1030
- attention_mask=None,
1031
- temb=None,
1032
- ):
1033
- self.denoise_step += 1
1034
- residual = hidden_states
1035
-
1036
- if attn.spatial_norm is not None:
1037
- hidden_states = attn.spatial_norm(hidden_states, temb)
1038
-
1039
- input_ndim = hidden_states.ndim
1040
-
1041
- if input_ndim == 4:
1042
- batch_size, channel, height, width = hidden_states.shape
1043
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1044
-
1045
- batch_size, sequence_length, _ = (
1046
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1047
- )
1048
-
1049
- if attention_mask is not None:
1050
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1051
- # scaled_dot_product_attention expects attention_mask shape to be
1052
- # (batch, heads, source_length, target_length)
1053
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1054
-
1055
- if attn.group_norm is not None:
1056
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1057
-
1058
- query = attn.to_q(hidden_states)
1059
-
1060
- if encoder_hidden_states is None:
1061
- encoder_hidden_states = hidden_states
1062
- else:
1063
- # get encoder_hidden_states, ip_hidden_states
1064
- end_pos = encoder_hidden_states.shape[1] - self.num_content_tokens-self.num_style_tokens
1065
- encoder_hidden_states, ip_style_hidden_states = (
1066
- encoder_hidden_states[:, :end_pos, :],
1067
- encoder_hidden_states[:, end_pos:, :],
1068
- )
1069
- if attn.norm_cross:
1070
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1071
-
1072
- key = attn.to_k(encoder_hidden_states)
1073
- value = attn.to_v(encoder_hidden_states)
1074
-
1075
- ## attention map
1076
- if self.save_attn_map:
1077
- attention_probs = attn.get_attention_scores(attn.head_to_batch_dim(query), attn.head_to_batch_dim(value), attention_mask)
1078
- if attention_probs is not None:
1079
- if not hasattr(attn, "attn_map"):
1080
- setattr(attn, "attn_map", {})
1081
- setattr(attn, "inference_step", 0)
1082
- else:
1083
- attn.inference_step += 1
1084
-
1085
- # # maybe we need to save all the timestep
1086
- # if attn.inference_step in self.attn_map_save_steps:
1087
- attn.attn_map[attn.inference_step] = attention_probs.clone().cpu().detach()
1088
- # attn.attn_map[attn.inference_step] = attention_probs.detach()
1089
- ## end of attention map
1090
- else:
1091
- print(f"{attn} didn't get the attention probs")
1092
-
1093
- inner_dim = key.shape[-1]
1094
- head_dim = inner_dim // attn.heads
1095
-
1096
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1097
-
1098
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1099
-
1100
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1101
-
1102
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
1103
- # TODO: add support for attn.scale when we move to Torch 2.1
1104
- # # modified the attnMap of the Stylization Image
1105
-
1106
- if self.fuAttn and self.denoise_step <= self.end_fusion:
1107
- assert query.shape[0] == 4
1108
- scale_factor = 1 / math.sqrt(torch.tensor(head_dim, dtype=query.dtype))
1109
- text_attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
1110
- text_attn_probs[1] = self.fuScale*text_attn_probs[1] + (1-self.fuScale)*text_attn_probs[0]
1111
- text_attn_probs[3] = self.fuScale*text_attn_probs[3] + (1-self.fuScale)*text_attn_probs[2]
1112
- hidden_states = torch.matmul(text_attn_probs, value)
1113
- else:
1114
- hidden_states = F.scaled_dot_product_attention(
1115
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1116
- )
1117
-
1118
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1119
- hidden_states = hidden_states.to(query.dtype)
1120
-
1121
- raw_hidden_states = hidden_states
1122
-
1123
- if not self.skip and self.style is True:
1124
-
1125
- # for ip-style-adapter
1126
- ip_style_key = self.to_k_ip(ip_style_hidden_states)
1127
- ip_style_value = self.to_v_ip(ip_style_hidden_states)
1128
-
1129
- ip_style_key = ip_style_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1130
- ip_style_value = ip_style_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1131
-
1132
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
1133
- # TODO: add support for attn.scale when we move to Torch 2.1
1134
- if self.fuIPAttn and self.denoise_step <= self.end_fusion:
1135
- assert query.shape[0] == 4
1136
- if "down" in self.name:
1137
- print("wrong! coding")
1138
- exit()
1139
- scale_factor = 1 / math.sqrt(torch.tensor(head_dim, dtype=query.dtype))
1140
- ip_attn_probs = torch.matmul(query, ip_style_key.transpose(-2, -1)) * scale_factor
1141
- ip_attn_probs = F.softmax(ip_attn_probs, dim=-1)
1142
- ip_attn_probs[1] = self.fuScale*ip_attn_probs[1] + (1-self.fuScale)*ip_attn_probs[0]
1143
- ip_attn_probs[3] = self.fuScale*ip_attn_probs[3] + (1-self.fuScale)*ip_attn_probs[2]
1144
- ip_style_hidden_states = torch.matmul(ip_attn_probs, ip_style_value)
1145
- else:
1146
- ip_style_hidden_states = F.scaled_dot_product_attention(
1147
- query, ip_style_key, ip_style_value, attn_mask=None, dropout_p=0.0, is_causal=False
1148
- )
1149
-
1150
- ip_style_hidden_states = ip_style_hidden_states.transpose(1, 2).reshape(batch_size, -1,
1151
- attn.heads * head_dim)
1152
- ip_style_hidden_states = ip_style_hidden_states.to(query.dtype)
1153
-
1154
- # if self.adainIP and self.denoise_step >= self.start_adain:
1155
- if self.adainIP:
1156
- # print("adain")
1157
- # if self.denoise_step == 1 and "up_blocks.1.attentions.2.transformer_blocks.1" in self.name:
1158
- # print("adain")
1159
- def adain(content, style):
1160
- content_mean = content.mean(dim=1, keepdim=True)
1161
- content_std = content.std(dim=1, keepdim=True)
1162
- print("exp code")
1163
- pdb.set_trace()
1164
- style_mean = style.mean(dim=1, keepdim=True)
1165
- style_std = style.std(dim=1, keepdim=True)
1166
- normalized_content = (content - content_mean) / content_std
1167
- stylized_content = normalized_content * style_std + style_mean
1168
- return stylized_content
1169
- pdb.set_trace()
1170
- hidden_states = adain(content=hidden_states, style=ip_style_hidden_states)
1171
- else:
1172
- hidden_states = hidden_states + self.style_scale * ip_style_hidden_states
1173
-
1174
- if hidden_states.shape[0] == 4:
1175
- hidden_states[0] = raw_hidden_states[0]
1176
- hidden_states[2] = raw_hidden_states[2]
1177
- # hidden_states = raw_hidden_states
1178
-
1179
- # linear proj
1180
- hidden_states = attn.to_out[0](hidden_states)
1181
- # dropout
1182
- hidden_states = attn.to_out[1](hidden_states)
1183
-
1184
- if input_ndim == 4:
1185
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1186
-
1187
- if attn.residual_connection:
1188
- hidden_states = hidden_states + residual
1189
-
1190
- hidden_states = hidden_states / attn.rescale_output_factor
1191
-
1192
- self.reset_denoise_step()
1193
- return hidden_states
1194
-
1195
  class AttnProcessor2_0_hijack(torch.nn.Module):
1196
  r"""
1197
  Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
@@ -1204,131 +769,8 @@ class AttnProcessor2_0_hijack(torch.nn.Module):
1204
  save_in_unet='down',
1205
  atten_control=None,
1206
  fuSAttn=False,
1207
- fuScale=0,
1208
- end_fusion=0,
1209
- attn_name=None,
1210
- ):
1211
- super().__init__()
1212
- if not hasattr(F, "scaled_dot_product_attention"):
1213
- raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
1214
- self.atten_control = atten_control
1215
- self.save_in_unet = save_in_unet
1216
-
1217
- self.fuSAttn = fuSAttn
1218
- self.fuScale = fuScale
1219
- self.denoise_step = 0
1220
- self.end_fusion = end_fusion
1221
- self.name = attn_name
1222
-
1223
- def reset_denoise_step(self):
1224
- if self.denoise_step == 50:
1225
- self.denoise_step = 0
1226
- # if "up_blocks.0.attentions.1.transformer_blocks.0.attn1" in self.name:
1227
- # print("attn1 reset successful")
1228
-
1229
- def __call__(
1230
- self,
1231
- attn,
1232
- hidden_states,
1233
- encoder_hidden_states=None,
1234
- attention_mask=None,
1235
- temb=None,
1236
- ):
1237
- self.denoise_step += 1
1238
- residual = hidden_states
1239
-
1240
- if attn.spatial_norm is not None:
1241
- hidden_states = attn.spatial_norm(hidden_states, temb)
1242
-
1243
- input_ndim = hidden_states.ndim
1244
-
1245
- if input_ndim == 4:
1246
- batch_size, channel, height, width = hidden_states.shape
1247
- hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
1248
-
1249
- batch_size, sequence_length, _ = (
1250
- hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
1251
- )
1252
-
1253
- if attention_mask is not None:
1254
- attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
1255
- # scaled_dot_product_attention expects attention_mask shape to be
1256
- # (batch, heads, source_length, target_length)
1257
- attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
1258
-
1259
- if attn.group_norm is not None:
1260
- hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
1261
-
1262
- query = attn.to_q(hidden_states)
1263
-
1264
- if encoder_hidden_states is None:
1265
- encoder_hidden_states = hidden_states
1266
- elif attn.norm_cross:
1267
- encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
1268
-
1269
- key = attn.to_k(encoder_hidden_states)
1270
- value = attn.to_v(encoder_hidden_states)
1271
-
1272
- inner_dim = key.shape[-1]
1273
- head_dim = inner_dim // attn.heads
1274
-
1275
- query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1276
-
1277
- key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1278
- value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1279
-
1280
- # the output of sdp = (batch, num_heads, seq_len, head_dim)
1281
- # TODO: add support for attn.scale when we move to Torch 2.1
1282
- if self.fuSAttn and self.denoise_step <= self.end_fusion:
1283
- assert query.shape[0] == 4
1284
- if "up_blocks.1.attentions.2.transformer_blocks.1" in self.name and self.denoise_step == self.end_fusion:
1285
- print("now: ", self.denoise_step, "end now:", self.end_fusion, "scale: ", self.fuScale)
1286
- # pdb.set_trace()
1287
- scale_factor = 1 / math.sqrt(torch.tensor(head_dim, dtype=query.dtype))
1288
- attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
1289
- attn_probs[1] = self.fuScale*attn_probs[1] + (1-self.fuScale)*attn_probs[0]
1290
- attn_probs[3] = self.fuScale*attn_probs[3] + (1-self.fuScale)*attn_probs[2]
1291
- hidden_states = torch.matmul(attn_probs, value)
1292
- else:
1293
- hidden_states = F.scaled_dot_product_attention(
1294
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1295
- )
1296
-
1297
- hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1298
- hidden_states = hidden_states.to(query.dtype)
1299
-
1300
- # linear proj
1301
- hidden_states = attn.to_out[0](hidden_states)
1302
- # dropout
1303
- hidden_states = attn.to_out[1](hidden_states)
1304
-
1305
- if input_ndim == 4:
1306
- hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
1307
-
1308
- if attn.residual_connection:
1309
- hidden_states = hidden_states + residual
1310
-
1311
- hidden_states = hidden_states / attn.rescale_output_factor
1312
-
1313
- if self.denoise_step == 50:
1314
- self.reset_denoise_step()
1315
- return hidden_states
1316
-
1317
- class AttnProcessor2_0_exp(torch.nn.Module):
1318
- r"""
1319
- Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
1320
- """
1321
-
1322
- def __init__(
1323
- self,
1324
- hidden_size=None,
1325
- cross_attention_dim=None,
1326
- save_in_unet='down',
1327
- atten_control=None,
1328
- fuSAttn=False,
1329
- fuScale=0,
1330
  end_fusion=0,
1331
- attn_name=None,
1332
  ):
1333
  super().__init__()
1334
  if not hasattr(F, "scaled_dot_product_attention"):
@@ -1337,16 +779,10 @@ class AttnProcessor2_0_exp(torch.nn.Module):
1337
  self.save_in_unet = save_in_unet
1338
 
1339
  self.fuSAttn = fuSAttn
1340
- self.fuScale = fuScale
1341
  self.denoise_step = 0
1342
  self.end_fusion = end_fusion
1343
- self.name = attn_name
1344
 
1345
- def reset_denoise_step(self):
1346
- if self.denoise_step == 50:
1347
- self.denoise_step = 0
1348
- # if "up_blocks.0.attentions.1.transformer_blocks.0.attn1" in self.name:
1349
- # print("attn1 reset successful")
1350
 
1351
  def __call__(
1352
  self,
@@ -1403,26 +839,10 @@ class AttnProcessor2_0_exp(torch.nn.Module):
1403
  # TODO: add support for attn.scale when we move to Torch 2.1
1404
  if self.fuSAttn and self.denoise_step <= self.end_fusion:
1405
  assert query.shape[0] == 4
1406
- if "up_blocks.1.attentions.2.transformer_blocks.1" in self.name and self.denoise_step == self.end_fusion:
1407
- print("now: ", self.denoise_step, "end now:", self.end_fusion, "scale: ", self.fuScale)
1408
- # pdb.set_trace()
1409
  scale_factor = 1 / math.sqrt(torch.tensor(head_dim, dtype=query.dtype))
1410
  attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
1411
-
1412
- attn_probs[1] = self.fuScale*attn_probs[1] + (1-self.fuScale)*attn_probs[0]
1413
- attn_probs[3] = self.fuScale*attn_probs[3] + (1-self.fuScale)*attn_probs[2]
1414
- print("exp code")
1415
- pdb.set_trace()
1416
- def adain(content, style):
1417
- content_mean = content.mean(dim=1, keepdim=True)
1418
- content_std = content.std(dim=1, keepdim=True)
1419
- style_mean = style.mean(dim=1, keepdim=True)
1420
- style_std = style.std(dim=1, keepdim=True)
1421
- normalized_content = (content - content_mean) / content_std
1422
- stylized_content = normalized_content * style_std + style_mean
1423
- return stylized_content
1424
- value[1] = adain(content=value[0], style=value[1])
1425
- value[3] = adain(content=value[2], style=value[3])
1426
  hidden_states = torch.matmul(attn_probs, value)
1427
  else:
1428
  hidden_states = F.scaled_dot_product_attention(
@@ -1445,7 +865,8 @@ class AttnProcessor2_0_exp(torch.nn.Module):
1445
 
1446
  hidden_states = hidden_states / attn.rescale_output_factor
1447
 
1448
- self.reset_denoise_step()
 
1449
  return hidden_states
1450
 
1451
  class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
@@ -1463,7 +884,7 @@ class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
1463
  """
1464
 
1465
  def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False,
1466
- fuAttn=False, fuIPAttn=False, adainIP=False, end_fusion=0, fuScale=0, attn_name=None):
1467
  super().__init__()
1468
 
1469
  if not hasattr(F, "scaled_dot_product_attention"):
@@ -1478,19 +899,12 @@ class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
1478
  self.fuAttn = fuAttn
1479
  self.fuIPAttn = fuIPAttn
1480
  self.adainIP = adainIP
1481
- self.denoise_step = fuScale
1482
  self.end_fusion = end_fusion
1483
- self.fuScale = fuScale
1484
- self.name = attn_name
1485
 
1486
  self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1487
  self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
1488
-
1489
- def reset_denoise_step(self):
1490
- if self.denoise_step == 50:
1491
- self.denoise_step = 0
1492
- # if "up_blocks.0.attentions.1.transformer_blocks.0.attn2" in self.name:
1493
- # print("attn2 reset successful")
1494
 
1495
  def __call__(
1496
  self,
@@ -1552,20 +966,10 @@ class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
1552
 
1553
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1554
  # TODO: add support for attn.scale when we move to Torch 2.1
1555
- if self.fuAttn and self.denoise_step <= self.end_fusion:
1556
- assert query.shape[0] == 4
1557
- if "up_blocks.1.attentions.2.transformer_blocks.1" in self.name and self.denoise_step == self.end_fusion:
1558
- print("fuAttn")
1559
- print("now: ", self.denoise_step, "end now:", self.end_fusion, "scale: ", self.fuScale)
1560
- scale_factor = 1 / math.sqrt(torch.tensor(head_dim, dtype=query.dtype))
1561
- text_attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
1562
- text_attn_probs[1] = self.fuScale*text_attn_probs[1] + (1-self.fuScale)*text_attn_probs[0]
1563
- text_attn_probs[3] = self.fuScale*text_attn_probs[3] + (1-self.fuScale)*text_attn_probs[2]
1564
- hidden_states = torch.matmul(text_attn_probs, value)
1565
- else:
1566
- hidden_states = F.scaled_dot_product_attention(
1567
- query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
1568
- )
1569
 
1570
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
1571
  hidden_states = hidden_states.to(query.dtype)
@@ -1582,22 +986,9 @@ class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
1582
 
1583
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
1584
  # TODO: add support for attn.scale when we move to Torch 2.1
1585
- if self.fuIPAttn and self.denoise_step <= self.end_fusion:
1586
- assert query.shape[0] == 4
1587
- print("fuIPAttn")
1588
- if "down" in self.name:
1589
- print("wrong! coding")
1590
- exit()
1591
- scale_factor = 1 / math.sqrt(torch.tensor(head_dim, dtype=query.dtype))
1592
- ip_attn_probs = torch.matmul(query, ip_key.transpose(-2, -1)) * scale_factor
1593
- ip_attn_probs = F.softmax(ip_attn_probs, dim=-1)
1594
- ip_attn_probs[1] = self.fuScale*ip_attn_probs[1] + (1-self.fuScale)*ip_attn_probs[0]
1595
- ip_attn_probs[3] = self.fuScale*ip_attn_probs[3] + (1-self.fuScale)*ip_attn_probs[2]
1596
- ip_hidden_states = torch.matmul(ip_attn_probs, ip_value)
1597
- else:
1598
- ip_hidden_states = F.scaled_dot_product_attention(
1599
- query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
1600
- )
1601
 
1602
  with torch.no_grad():
1603
  self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
@@ -1639,7 +1030,7 @@ class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
1639
 
1640
  hidden_states = hidden_states / attn.rescale_output_factor
1641
 
1642
- if self.denoise_step == 50:
1643
- self.reset_denoise_step()
1644
 
1645
  return hidden_states
 
757
 
758
  return hidden_states
759
 
 
 
 
 
 
 
760
  class AttnProcessor2_0_hijack(torch.nn.Module):
761
  r"""
762
  Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
 
769
  save_in_unet='down',
770
  atten_control=None,
771
  fuSAttn=False,
 
 
 
 
 
 
 
772
  end_fusion=0,
773
+ num_inference_step=50,
774
  ):
775
  super().__init__()
776
  if not hasattr(F, "scaled_dot_product_attention"):
 
779
  self.save_in_unet = save_in_unet
780
 
781
  self.fuSAttn = fuSAttn
 
782
  self.denoise_step = 0
783
  self.end_fusion = end_fusion
784
+ self.num_inference_step=num_inference_step
785
 
 
 
 
 
 
786
 
787
  def __call__(
788
  self,
 
839
  # TODO: add support for attn.scale when we move to Torch 2.1
840
  if self.fuSAttn and self.denoise_step <= self.end_fusion:
841
  assert query.shape[0] == 4
 
 
 
842
  scale_factor = 1 / math.sqrt(torch.tensor(head_dim, dtype=query.dtype))
843
  attn_probs = (torch.matmul(query, key.transpose(-2, -1)) * scale_factor).softmax(dim=-1)
844
+ attn_probs[1] = attn_probs[0]
845
+ attn_probs[3] = attn_probs[2]
 
 
846
  hidden_states = torch.matmul(attn_probs, value)
847
  else:
848
  hidden_states = F.scaled_dot_product_attention(
 
865
 
866
  hidden_states = hidden_states / attn.rescale_output_factor
867
 
868
+ if self.denoise_step == self.num_inference_step:
869
+ self.denoise_step == 0
870
  return hidden_states
871
 
872
  class IPAttnProcessor2_0_cross_modal(torch.nn.Module):
 
884
  """
885
 
886
  def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4, skip=False,
887
+ fuAttn=False, fuIPAttn=False, adainIP=False, end_fusion=0, num_inference_step=50):
888
  super().__init__()
889
 
890
  if not hasattr(F, "scaled_dot_product_attention"):
 
899
  self.fuAttn = fuAttn
900
  self.fuIPAttn = fuIPAttn
901
  self.adainIP = adainIP
902
+ self.denoise_step = 0
903
  self.end_fusion = end_fusion
904
+ self.num_inference_step = num_inference_step
 
905
 
906
  self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
907
  self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
 
908
 
909
  def __call__(
910
  self,
 
966
 
967
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
968
  # TODO: add support for attn.scale when we move to Torch 2.1
969
+
970
+ hidden_states = F.scaled_dot_product_attention(
971
+ query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
972
+ )
 
973
 
974
  hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
975
  hidden_states = hidden_states.to(query.dtype)
 
986
 
987
  # the output of sdp = (batch, num_heads, seq_len, head_dim)
988
  # TODO: add support for attn.scale when we move to Torch 2.1
989
+ ip_hidden_states = F.scaled_dot_product_attention(
990
+ query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
991
+ )
 
992
 
993
  with torch.no_grad():
994
  self.attn_map = query @ ip_key.transpose(-2, -1).softmax(dim=-1)
 
1030
 
1031
  hidden_states = hidden_states / attn.rescale_output_factor
1032
 
1033
+ if self.denoise_step == self.num_inference_step:
1034
+ self.denoise_step == 0
1035
 
1036
  return hidden_states
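The kept hijack processor reduces to a few tensor operations: for the first end_fusion denoising steps, the self-attention probabilities of batch entries 1 and 3 are overwritten with those of entries 0 and 2 (the batch of four asserted in the diff), and the step counter is meant to wrap once it reaches num_inference_step. Note that the added reset line reads self.denoise_step == 0, a comparison with no effect; the stand-alone sketch below assumes an assignment was intended. Shapes and the function name are illustrative, not the full AttnProcessor2_0_hijack class:

import math
import torch
import torch.nn.functional as F

def fused_self_attention(query, key, value, denoise_step, end_fusion):
    # query / key / value: (batch=4, heads, seq_len, head_dim)
    head_dim = query.shape[-1]
    if denoise_step <= end_fusion:
        scale = 1.0 / math.sqrt(head_dim)
        attn_probs = (query @ key.transpose(-2, -1) * scale).softmax(dim=-1)
        # share the attention maps of entries 0 and 2 with entries 1 and 3
        attn_probs[1] = attn_probs[0]
        attn_probs[3] = attn_probs[2]
        return attn_probs @ value
    return F.scaled_dot_product_attention(query, key, value)

q = k = v = torch.randn(4, 8, 77, 64)
out = fused_self_attention(q, k, v, denoise_step=5, end_fusion=10)
print(out.shape)  # torch.Size([4, 8, 77, 64])

# per-step bookkeeping as kept in the processor (assignment, not comparison):
denoise_step, num_inference_step = 50, 50
if denoise_step == num_inference_step:
    denoise_step = 0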
ip_adapter/ip_adapter.py CHANGED
@@ -22,8 +22,6 @@ if is_torch2_available():
22
  IPAttnProcessor2_0 as IPAttnProcessor,
23
  )
24
  from .attention_processor import IP_CS_AttnProcessor2_0 as IP_CS_AttnProcessor
25
- from .attention_processor import IP_FuAd_AttnProcessor2_0 as IP_FuAd_AttnProcessor
26
- from .attention_processor import IP_FuAd_AttnProcessor2_0_exp as IP_FuAd_AttnProcessor_exp
27
  from .attention_processor import AttnProcessor2_0_exp as AttnProcessor_exp
28
  from .attention_processor import AttnProcessor2_0_hijack as AttnProcessor_hijack
29
  from .attention_processor import IPAttnProcessor2_0_cross_modal as IPAttnProcessor_cross_modal
@@ -949,7 +947,7 @@ class StyleStudio_Adapter(CSGO):
949
  if block_name in name:
950
  selected = True
951
  # print(name)
952
- attn_procs[name] = IP_FuAd_AttnProcessor(
953
  hidden_size=hidden_size,
954
  cross_attention_dim=cross_attention_dim,
955
  style_scale=1.0,
@@ -963,7 +961,7 @@ class StyleStudio_Adapter(CSGO):
963
  attn_name=name,
964
  )
965
  if selected is False:
966
- attn_procs[name] = IP_FuAd_AttnProcessor(
967
  hidden_size=hidden_size,
968
  cross_attention_dim=cross_attention_dim,
969
  num_style_tokens=self.num_style_tokens,
@@ -1011,7 +1009,7 @@ class StyleStudio_Adapter(CSGO):
1011
 
1012
  def set_scale(self, style_scale):
1013
  for attn_processor in self.pipe.unet.attn_processors.values():
1014
- if isinstance(attn_processor, IP_FuAd_AttnProcessor):
1015
  if attn_processor.style is True:
1016
  attn_processor.style_scale = style_scale
1017
  # print('style_scale:',style_scale)
@@ -1100,9 +1098,14 @@ class StyleStudio_Adapter(CSGO):
1100
  if isinstance(attn_processor, AttnProcessor_hijack):
1101
  attn_processor.fuSAttn = use_SAttn
1102
 
 
 
 
 
 
1103
  def set_adain(self, use_CMA):
1104
  for attn_processor in self.pipe.unet.attn_processors.values():
1105
- if isinstance(attn_processor, IP_FuAd_AttnProcessor):
1106
  attn_processor.adainIP = use_CMA
1107
 
1108
  def generate(
@@ -1125,6 +1128,7 @@ class StyleStudio_Adapter(CSGO):
1125
  self.set_endFusion(end_T = end_fusion)
1126
  self.set_adain(use_CMA=cross_modal_adain)
1127
  self.set_SAttn(use_SAttn=use_SAttn)
 
1128
 
1129
  # self.set_scale(style_scale=style_scale)
1130
  num_prompts = 1 if isinstance(pil_style_image, Image.Image) else len(pil_style_image)
@@ -1188,93 +1192,6 @@ class StyleStudio_Adapter(CSGO):
1188
  ).images
1189
  return images
1190
 
1191
- # StyleStudio_Adapter experiment code
1192
- class StyleStudio_Adapter_exp(StyleStudio_Adapter):
1193
- def set_ip_adapter(self):
1194
- unet = self.pipe.unet
1195
- attn_procs = {}
1196
- for name in unet.attn_processors.keys():
1197
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
1198
- if name.startswith("mid_block"):
1199
- hidden_size = unet.config.block_out_channels[-1]
1200
- elif name.startswith("up_blocks"):
1201
- block_id = int(name[len("up_blocks.")])
1202
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
1203
- elif name.startswith("down_blocks"):
1204
- block_id = int(name[len("down_blocks.")])
1205
- hidden_size = unet.config.block_out_channels[block_id]
1206
- if cross_attention_dim is None:
1207
- attn_procs[name] = AttnProcessor_exp(
1208
- fuSAttn=self.fuSAttn,
1209
- fuScale=self.fuScale,
1210
- end_fusion=self.end_fusion,
1211
- attn_name=name)
1212
- else:
1213
- # layername_id += 1
1214
- selected = False
1215
- for block_name in self.style_target_blocks:
1216
- if block_name in name:
1217
- selected = True
1218
- # print(name)
1219
- # change every attention processor in the StyleBlocks to FuAdAttn
1220
- attn_procs[name] = IP_FuAd_AttnProcessor_exp(
1221
- hidden_size=hidden_size,
1222
- cross_attention_dim=cross_attention_dim,
1223
- style_scale=1.0,
1224
- style=True,
1225
- num_content_tokens=self.num_content_tokens,
1226
- num_style_tokens=self.num_style_tokens,
1227
- fuAttn=self.fuAttn,
1228
- fuIPAttn=self.fuIPAttn,
1229
- adainIP=self.adainIP,
1230
- fuScale=self.fuScale,
1231
- end_fusion=self.end_fusion,
1232
- attn_name=name,
1233
- save_attn_map=self.save_attn_map,
1234
- )
1235
- # there is no CSGO-style Content Control requirement here, so the cross-attention over content tokens is removed
1237
- # this also appears to be a problematic part of the CSGO code; in any case it is overwritten by the later reset
1238
- # and in CSGO's design the Content Blocks and Style Blocks do not overlap
1239
- # selected == False means this is not a Style Block; the key point is skip = True
1239
- if selected is False:
1240
- attn_procs[name] = IP_FuAd_AttnProcessor_exp(
1241
- hidden_size=hidden_size,
1242
- cross_attention_dim=cross_attention_dim,
1243
- num_content_tokens=self.num_content_tokens,
1244
- num_style_tokens=self.num_style_tokens,
1245
- skip=True,
1246
- fuAttn=self.fuAttn,
1247
- fuIPAttn=self.fuIPAttn,
1248
- adainIP=self.adainIP,
1249
- fuScale=self.fuScale,
1250
- end_fusion=self.end_fusion,
1251
- attn_name=name,
1252
- save_attn_map=self.save_attn_map,
1253
- )
1254
- # attn_procs[name] = IP_FuAd_AttnProcessor_exp(
1255
- # hidden_size=hidden_size,
1256
- # cross_attention_dim=cross_attention_dim,
1257
- # num_content_tokens=self.num_content_tokens,
1258
- # num_style_tokens=self.num_style_tokens,
1259
- # skip=True,
1260
- # fuAttn=self.fuAttn,
1261
- # fuIPAttn=self.fuIPAttn,
1262
- # )
1263
-
1264
- attn_procs[name].to(self.device, dtype=torch.float16)
1265
- unet.set_attn_processor(attn_procs)
1266
- if hasattr(self.pipe, "controlnet"):
1267
- if self.controlnet_adapter is False:
1268
- if isinstance(self.pipe.controlnet, MultiControlNetModel):
1269
- for controlnet in self.pipe.controlnet.nets:
1270
- controlnet.set_attn_processor(CNAttnProcessor(
1271
- num_tokens=self.num_content_tokens + self.num_style_tokens))
1272
- else:
1273
- self.pipe.controlnet.set_attn_processor(CNAttnProcessor(
1274
- num_tokens=self.num_content_tokens + self.num_style_tokens))
1275
- # because our code has no ControlNet that needs style injected (this is not an I2I task),
1276
- # the part of CSGO that injects style into the ControlNet has been removed
1277
-
1278
  class IPAdapterXL(IPAdapter):
1279
  """SDXL"""
1280
 
@@ -1361,397 +1278,4 @@ class IPAdapterXL(IPAdapter):
1361
  **kwargs,
1362
  ).images
1363
 
1364
- return images
1365
-
1366
-
1367
- class IPAdapterXL_cross_modal(IPAdapterXL):
1368
- def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4,
1369
- target_blocks=["block"],
1370
- fuAttn=False,
1371
- fuSAttn=False,
1372
- fuIPAttn=False,
1373
- fuScale=0,
1374
- adainIP=False,
1375
- end_fusion=0,
1376
- save_attn_map=False,):
1377
- self.fuAttn = fuAttn
1378
- self.fuSAttn = fuSAttn
1379
- self.fuIPAttn = fuIPAttn
1380
- self.adainIP = adainIP
1381
- self.fuScale = fuScale
1382
- if self.fuSAttn:
1383
- print(f"hijack Self AttnMap in {end_fusion} steps", "fuScale is: ", fuScale)
1384
- if self.fuAttn:
1385
- print(f"hijack Cross AttnMap in {end_fusion} steps", "fuScale is: ", fuScale)
1386
- if self.fuIPAttn:
1387
- print(f"hijack IP AttnMap in {end_fusion} steps", "fuScale is: ", fuScale)
1388
- self.end_fusion = end_fusion
1389
- self.save_attn_map = save_attn_map
1390
-
1391
- self.device = device
1392
- self.image_encoder_path = image_encoder_path
1393
- self.ip_ckpt = ip_ckpt
1394
- self.num_tokens = num_tokens
1395
- self.target_blocks = target_blocks
1396
-
1397
- self.pipe = sd_pipe.to(self.device)
1398
- self.set_ip_adapter()
1399
-
1400
- # load image encoder
1401
- self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
1402
- self.device, dtype=torch.float16
1403
- )
1404
- self.clip_image_processor = CLIPImageProcessor()
1405
- # image proj model
1406
- self.image_proj_model = self.init_proj()
1407
-
1408
- self.load_ip_adapter()
1409
-
1410
- def init_proj(self):
1411
- image_proj_model = ImageProjModel(
1412
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
1413
- clip_embeddings_dim=self.image_encoder.config.projection_dim,
1414
- clip_extra_context_tokens=self.num_tokens,
1415
- ).to(self.device, dtype=torch.float16)
1416
- return image_proj_model
1417
-
1418
- def set_ip_adapter(self):
1419
- unet = self.pipe.unet
1420
- attn_procs = {}
1421
- for name in unet.attn_processors.keys():
1422
- cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
1423
- if name.startswith("mid_block"):
1424
- hidden_size = unet.config.block_out_channels[-1]
1425
- elif name.startswith("up_blocks"):
1426
- block_id = int(name[len("up_blocks.")])
1427
- hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
1428
- elif name.startswith("down_blocks"):
1429
- block_id = int(name[len("down_blocks.")])
1430
- hidden_size = unet.config.block_out_channels[block_id]
1431
- if cross_attention_dim is None:
1432
- attn_procs[name] = AttnProcessor_hijack(
1433
- fuSAttn=self.fuSAttn,
1434
- fuScale=self.fuScale,
1435
- end_fusion=self.end_fusion,
1436
- attn_name=name) # Self Attention
1437
- else: # Cross Attention
1438
- selected = False
1439
- for block_name in self.target_blocks:
1440
- if block_name in name:
1441
- selected = True
1442
- break
1443
- if selected:
1444
- attn_procs[name] = IPAttnProcessor_cross_modal(
1445
- hidden_size=hidden_size,
1446
- cross_attention_dim=cross_attention_dim,
1447
- scale=1.0,
1448
- num_tokens=self.num_tokens,
1449
- fuAttn=self.fuAttn,
1450
- fuIPAttn=self.fuIPAttn,
1451
- adainIP=self.adainIP,
1452
- fuScale=self.fuScale,
1453
- end_fusion=self.end_fusion,
1454
- attn_name=name,
1455
- ).to(self.device, dtype=torch.float16)
1456
- else:
1457
- attn_procs[name] = IPAttnProcessor_cross_modal(
1458
- hidden_size=hidden_size,
1459
- cross_attention_dim=cross_attention_dim,
1460
- scale=1.0,
1461
- num_tokens=self.num_tokens,
1462
- skip=True,
1463
- fuAttn=self.fuAttn,
1464
- fuIPAttn=self.fuIPAttn,
1465
- adainIP=self.adainIP,
1466
- fuScale=self.fuScale,
1467
- end_fusion=self.end_fusion,
1468
- attn_name=name,
1469
- ).to(self.device, dtype=torch.float16)
1470
- unet.set_attn_processor(attn_procs)
1471
- if hasattr(self.pipe, "controlnet"):
1472
- if isinstance(self.pipe.controlnet, MultiControlNetModel):
1473
- for controlnet in self.pipe.controlnet.nets:
1474
- controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
1475
- else:
1476
- self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
1477
-
1478
- def load_ip_adapter(self):
1479
- if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
1480
- state_dict = {"image_proj": {}, "ip_adapter": {}}
1481
- with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
1482
- for key in f.keys():
1483
- if key.startswith("image_proj."):
1484
- state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
1485
- elif key.startswith("ip_adapter."):
1486
- state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
1487
- else:
1488
- state_dict = torch.load(self.ip_ckpt, map_location="cpu")
1489
- self.image_proj_model.load_state_dict(state_dict["image_proj"])
1490
- ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
1491
- ip_layers.load_state_dict(state_dict["ip_adapter"], strict=False)
1492
-
1493
- @torch.inference_mode()
1494
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
1495
- if pil_image is not None:
1496
- if isinstance(pil_image, Image.Image):
1497
- pil_image = [pil_image]
1498
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
1499
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
1500
- else:
1501
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
1502
-
1503
- if content_prompt_embeds is not None:
1504
- clip_image_embeds = clip_image_embeds - content_prompt_embeds
1505
-
1506
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
1507
- uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
1508
- return image_prompt_embeds, uncond_image_prompt_embeds
1509
-
1510
- def set_scale(self, scale):
1511
- for attn_processor in self.pipe.unet.attn_processors.values():
1512
- if isinstance(attn_processor, IPAttnProcessor_cross_modal):
1513
- attn_processor.scale = scale
1514
-
1515
- @torch.inference_mode()
1516
- def get_neg_image_embeds(self, pil_image=None, clip_image_embeds=None, content_prompt_embeds=None):
1517
- if pil_image is not None:
1518
- if isinstance(pil_image, Image.Image):
1519
- pil_image = [pil_image]
1520
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
1521
- clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
1522
- else:
1523
- clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
1524
-
1525
- if content_prompt_embeds is not None:
1526
- clip_image_embeds = clip_image_embeds - content_prompt_embeds
1527
-
1528
- neg_image_prompt_embeds = self.image_proj_model(clip_image_embeds)
1529
- return neg_image_prompt_embeds
1530
-
1531
- def generate(
1532
- self,
1533
- pil_image,
1534
- neg_pil_image=None,
1535
- prompt=None,
1536
- negative_prompt=None,
1537
- scale=1.0,
1538
- num_samples=4,
1539
- seed=None,
1540
- num_inference_steps=30,
1541
- neg_content_emb=None,
1542
- neg_content_prompt=None,
1543
- neg_content_scale=1.0,
1544
- **kwargs,
1545
- ):
1546
- self.set_scale(scale)
1547
-
1548
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
1549
-
1550
- if prompt is None:
1551
- prompt = "best quality, high quality"
1552
- if negative_prompt is None:
1553
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
1554
-
1555
- if not isinstance(prompt, List):
1556
- prompt = [prompt] * num_prompts
1557
- if not isinstance(negative_prompt, List):
1558
- negative_prompt = [negative_prompt] * num_prompts
1559
-
1560
- if neg_content_emb is None:
1561
- if neg_content_prompt is not None:
1562
- with torch.inference_mode():
1563
- (
1564
- prompt_embeds_, # torch.Size([1, 77, 2048])
1565
- negative_prompt_embeds_,
1566
- pooled_prompt_embeds_, # torch.Size([1, 1280])
1567
- negative_pooled_prompt_embeds_,
1568
- ) = self.pipe.encode_prompt(
1569
- neg_content_prompt,
1570
- num_images_per_prompt=num_samples,
1571
- do_classifier_free_guidance=True,
1572
- negative_prompt=negative_prompt,
1573
- )
1574
- pooled_prompt_embeds_ *= neg_content_scale
1575
- else:
1576
- pooled_prompt_embeds_ = neg_content_emb
1577
- else:
1578
- pooled_prompt_embeds_ = None
1579
-
1580
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image, content_prompt_embeds=pooled_prompt_embeds_)
1581
-
1582
- if neg_pil_image is not None:
1583
- neg_image_prompt_embeds = self.get_neg_image_embeds(neg_pil_image)
1584
- cos_sim_neg = F.cosine_similarity(image_prompt_embeds, neg_image_prompt_embeds.squeeze(0).unsqueeze(1), dim=-1)
1585
- cos_sim_uncond = F.cosine_similarity(image_prompt_embeds, uncond_image_prompt_embeds.squeeze(0).unsqueeze(1), dim=-1)
1586
- print(f"neg cos sim is: {cos_sim_neg.diagonal()}")
1587
- print(f"uncond cos sim is: {cos_sim_uncond.diagonal()}")
1588
- uncond_image_prompt_embeds = neg_image_prompt_embeds
1589
-
1590
- bs_embed, seq_len, _ = image_prompt_embeds.shape
1591
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
1592
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
1593
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
1594
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
1595
-
1596
- with torch.inference_mode():
1597
- (
1598
- prompt_embeds,
1599
- negative_prompt_embeds,
1600
- pooled_prompt_embeds,
1601
- negative_pooled_prompt_embeds,
1602
- ) = self.pipe.encode_prompt(
1603
- prompt,
1604
- num_images_per_prompt=num_samples,
1605
- do_classifier_free_guidance=True,
1606
- negative_prompt=negative_prompt,
1607
- )
1608
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
1609
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
1610
-
1611
- # self.generator = get_generator(seed, self.device)
1612
-
1613
- images = self.pipe(
1614
- prompt_embeds=prompt_embeds,
1615
- negative_prompt_embeds=negative_prompt_embeds,
1616
- pooled_prompt_embeds=pooled_prompt_embeds,
1617
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1618
- num_inference_steps=num_inference_steps,
1619
- # generator=self.generator,
1620
- **kwargs,
1621
- ).images
1622
-
1623
- return images
1624
-
1625
-
1626
- class IPAdapterPlus(IPAdapter):
1627
- """IP-Adapter with fine-grained features"""
1628
-
1629
- def init_proj(self):
1630
- image_proj_model = Resampler(
1631
- dim=self.pipe.unet.config.cross_attention_dim,
1632
- depth=4,
1633
- dim_head=64,
1634
- heads=12,
1635
- num_queries=self.num_tokens,
1636
- embedding_dim=self.image_encoder.config.hidden_size,
1637
- output_dim=self.pipe.unet.config.cross_attention_dim,
1638
- ff_mult=4,
1639
- ).to(self.device, dtype=torch.float16)
1640
- return image_proj_model
1641
-
1642
- @torch.inference_mode()
1643
- def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
1644
- if isinstance(pil_image, Image.Image):
1645
- pil_image = [pil_image]
1646
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
1647
- clip_image = clip_image.to(self.device, dtype=torch.float16)
1648
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
1649
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
1650
- uncond_clip_image_embeds = self.image_encoder(
1651
- torch.zeros_like(clip_image), output_hidden_states=True
1652
- ).hidden_states[-2]
1653
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
1654
- return image_prompt_embeds, uncond_image_prompt_embeds
1655
-
1656
-
1657
- class IPAdapterFull(IPAdapterPlus):
1658
- """IP-Adapter with full features"""
1659
-
1660
- def init_proj(self):
1661
- image_proj_model = MLPProjModel(
1662
- cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
1663
- clip_embeddings_dim=self.image_encoder.config.hidden_size,
1664
- ).to(self.device, dtype=torch.float16)
1665
- return image_proj_model
1666
-
1667
-
1668
- class IPAdapterPlusXL(IPAdapter):
1669
- """SDXL"""
1670
-
1671
- def init_proj(self):
1672
- image_proj_model = Resampler(
1673
- dim=1280,
1674
- depth=4,
1675
- dim_head=64,
1676
- heads=20,
1677
- num_queries=self.num_tokens,
1678
- embedding_dim=self.image_encoder.config.hidden_size,
1679
- output_dim=self.pipe.unet.config.cross_attention_dim,
1680
- ff_mult=4,
1681
- ).to(self.device, dtype=torch.float16)
1682
- return image_proj_model
1683
-
1684
- @torch.inference_mode()
1685
- def get_image_embeds(self, pil_image):
1686
- if isinstance(pil_image, Image.Image):
1687
- pil_image = [pil_image]
1688
- clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
1689
- clip_image = clip_image.to(self.device, dtype=torch.float16)
1690
- clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
1691
- image_prompt_embeds = self.image_proj_model(clip_image_embeds)
1692
- uncond_clip_image_embeds = self.image_encoder(
1693
- torch.zeros_like(clip_image), output_hidden_states=True
1694
- ).hidden_states[-2]
1695
- uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
1696
- return image_prompt_embeds, uncond_image_prompt_embeds
1697
-
1698
- def generate(
1699
- self,
1700
- pil_image,
1701
- prompt=None,
1702
- negative_prompt=None,
1703
- scale=1.0,
1704
- num_samples=4,
1705
- seed=None,
1706
- num_inference_steps=30,
1707
- **kwargs,
1708
- ):
1709
- self.set_scale(scale)
1710
-
1711
- num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
1712
-
1713
- if prompt is None:
1714
- prompt = "best quality, high quality"
1715
- if negative_prompt is None:
1716
- negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
1717
-
1718
- if not isinstance(prompt, List):
1719
- prompt = [prompt] * num_prompts
1720
- if not isinstance(negative_prompt, List):
1721
- negative_prompt = [negative_prompt] * num_prompts
1722
-
1723
- image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
1724
- bs_embed, seq_len, _ = image_prompt_embeds.shape
1725
- image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
1726
- image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
1727
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
1728
- uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
1729
-
1730
- with torch.inference_mode():
1731
- (
1732
- prompt_embeds,
1733
- negative_prompt_embeds,
1734
- pooled_prompt_embeds,
1735
- negative_pooled_prompt_embeds,
1736
- ) = self.pipe.encode_prompt(
1737
- prompt,
1738
- num_images_per_prompt=num_samples,
1739
- do_classifier_free_guidance=True,
1740
- negative_prompt=negative_prompt,
1741
- )
1742
- prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
1743
- negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
1744
-
1745
- generator = get_generator(seed, self.device)
1746
-
1747
- images = self.pipe(
1748
- prompt_embeds=prompt_embeds,
1749
- negative_prompt_embeds=negative_prompt_embeds,
1750
- pooled_prompt_embeds=pooled_prompt_embeds,
1751
- negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
1752
- num_inference_steps=num_inference_steps,
1753
- generator=generator,
1754
- **kwargs,
1755
- ).images
1756
-
1757
- return images
 
22
  IPAttnProcessor2_0 as IPAttnProcessor,
23
  )
24
  from .attention_processor import IP_CS_AttnProcessor2_0 as IP_CS_AttnProcessor
 
 
25
  from .attention_processor import AttnProcessor2_0_exp as AttnProcessor_exp
26
  from .attention_processor import AttnProcessor2_0_hijack as AttnProcessor_hijack
27
  from .attention_processor import IPAttnProcessor2_0_cross_modal as IPAttnProcessor_cross_modal
 
947
  if block_name in name:
948
  selected = True
949
  # print(name)
950
+ attn_procs[name] = IPAttnProcessor_cross_modal(
951
  hidden_size=hidden_size,
952
  cross_attention_dim=cross_attention_dim,
953
  style_scale=1.0,
 
961
  attn_name=name,
962
  )
963
  if selected is False:
964
+ attn_procs[name] = IPAttnProcessor_cross_modal(
965
  hidden_size=hidden_size,
966
  cross_attention_dim=cross_attention_dim,
967
  num_style_tokens=self.num_style_tokens,
 
1009
 
1010
  def set_scale(self, style_scale):
1011
  for attn_processor in self.pipe.unet.attn_processors.values():
1012
+ if isinstance(attn_processor, IPAttnProcessor_cross_modal):
1013
  if attn_processor.style is True:
1014
  attn_processor.style_scale = style_scale
1015
  # print('style_scale:',style_scale)
 
1098
  if isinstance(attn_processor, AttnProcessor_hijack):
1099
  attn_processor.fuSAttn = use_SAttn
1100
 
1101
+ def set_num_inference_step(self, num_T):
1102
+ for attn_processor in self.pipe.unet.attn_processors.values():
1103
+ if isinstance(attn_processor, AttnProcessor_hijack) or isinstance(attn_processor, IPAttnProcessor_cross_modal):
1104
+ attn_processor.num_inference_step = num_T
1105
+
1106
  def set_adain(self, use_CMA):
1107
  for attn_processor in self.pipe.unet.attn_processors.values():
1108
+ if isinstance(attn_processor, IPAttnProcessor_cross_modal):
1109
  attn_processor.adainIP = use_CMA
1110
 
1111
  def generate(
 
1128
  self.set_endFusion(end_T = end_fusion)
1129
  self.set_adain(use_CMA=cross_modal_adain)
1130
  self.set_SAttn(use_SAttn=use_SAttn)
1131
+ self.set_num_inference_step(num_T=num_inference_steps)
1132
 
1133
  # self.set_scale(style_scale=style_scale)
1134
  num_prompts = 1 if isinstance(pil_style_image, Image.Image) else len(pil_style_image)
 
1192
  ).images
1193
  return images
1194
 
 
 
 
 
 
1195
  class IPAdapterXL(IPAdapter):
1196
  """SDXL"""
1197
 
 
1278
  **kwargs,
1279
  ).images
1280
 
1281
+ return images
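With the experimental classes removed, StyleStudio_Adapter now routes every selected cross-attention layer through IPAttnProcessor_cross_modal and pushes the run settings (end_fusion, cross-modal AdaIN, shared self-attention, num_inference_steps) into the processors before generation. A minimal sketch of the cross-modal AdaIN those processors apply when adainIP is enabled, following the adain() helper visible in the removed classes above (shapes are assumed; only the helper itself is taken from the diff):

import torch

def adain(content: torch.Tensor, style: torch.Tensor) -> torch.Tensor:
    # renormalize the text-branch attention output to the per-channel statistics
    # of the style branch, computed over the token dimension (dim=1)
    content_mean = content.mean(dim=1, keepdim=True)
    content_std = content.std(dim=1, keepdim=True)
    style_mean = style.mean(dim=1, keepdim=True)
    style_std = style.std(dim=1, keepdim=True)
    return (content - content_mean) / content_std * style_std + style_mean

hidden_states = torch.randn(2, 4096, 1280)            # text cross-attention output
ip_style_hidden_states = torch.randn(2, 4096, 1280)   # style (IP) cross-attention output
stylized = adain(content=hidden_states, style=ip_style_hidden_states)
print(stylized.shape)  # torch.Size([2, 4096, 1280])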