Leimingkun committed on
Commit
835dcb7
1 Parent(s): f343ea1

stylestudio

Files changed (3):
  1. .gitattributes +1 -2
  2. app.py +1 -0
  3. app_exp.py +0 -244
.gitattributes CHANGED
@@ -32,5 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.xz filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
-app_exp.py
+*tfevents* filter=lfs diff=lfs merge=lfs -text
 
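For context, each surviving line here is a Git LFS rule: `<pattern> filter=lfs diff=lfs merge=lfs -text` routes matching files through LFS instead of plain Git. The stray `app_exp.py` entry carried no filter directives, so it did nothing useful; this hunk drops it and keeps the `*tfevents*` rule. A hypothetical check (not part of the repo) that approximates the matching with shell-style globs:

```python
# Hypothetical helper, not in this repo: approximate which paths the
# remaining .gitattributes patterns would route through Git LFS.
# fnmatch only approximates gitattributes glob semantics.
from fnmatch import fnmatch

LFS_PATTERNS = ["*.xz", "*.zip", "*.zst", "*tfevents*"]

def routed_through_lfs(path: str) -> bool:
    return any(fnmatch(path, pattern) for pattern in LFS_PATTERNS)

print(routed_through_lfs("events.out.tfevents.1700000000"))  # True
print(routed_through_lfs("app_exp.py"))                      # False
```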
app.py CHANGED
@@ -100,6 +100,7 @@ def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
         seed = random.randint(0, MAX_SEED)
     return seed
 
+@spaces.GPU
 def create_image(style_image_pil,
                  prompt,
                  neg_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
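The only change to `app.py` is the new `@spaces.GPU` decorator from the `spaces` package, which on Hugging Face ZeroGPU Spaces requests a GPU for the duration of each call to the decorated function (and is a no-op elsewhere). A minimal sketch of the pattern, with a placeholder body rather than the demo's real pipeline call:

```python
# Minimal sketch of the ZeroGPU pattern added in this commit.
# `run` is a placeholder, not the demo's create_image; on hardware
# without ZeroGPU the decorator has no effect and the code still runs.
import spaces
import torch

@spaces.GPU  # allocate a GPU while this function executes
def run(prompt: str) -> str:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    return f"inference for {prompt!r} would run on {device}"
```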
app_exp.py DELETED
@@ -1,244 +0,0 @@
-import sys
-sys.path.append("./")
-import gradio as gr
-import spaces
-import torch
-from ip_adapter.utils import BLOCKS as BLOCKS
-import numpy as np
-import random
-from diffusers import (
-    AutoencoderKL,
-    StableDiffusionXLPipeline,
-)
-from ip_adapter import StyleStudio_Adapter
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-dtype = torch.float16 if str(device).__contains__("cuda") else torch.float32
-base_model_path = "/mnt/agilab/models/sdxl"
-image_encoder_path = "/mnt/agilab/models/ipadapter_sdxl/image_encoder"
-csgo_ckpt = "/mnt/agilab/models/CSGO/csgo_4_32.bin"
-pretrained_vae_name_or_path = '/mnt/agilab/models/madebyollin_sdxl-vae-fp16-fix'
-weight_dtype = torch.float16
-
-vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=torch.float16)
-pipe = StableDiffusionXLPipeline.from_pretrained(
-    base_model_path,
-    torch_dtype=torch.float16,
-    add_watermarker=False,
-    vae=vae
-)
-pipe.enable_vae_tiling()
-
-target_style_blocks = BLOCKS['style']
-
-csgo = StyleStudio_Adapter(
-    pipe, image_encoder_path, csgo_ckpt, device, num_style_tokens=32,
-    target_style_blocks=target_style_blocks,
-    controlnet_adapter=False,
-    style_model_resampler=True,
-
-    fuSAttn=True,
-    end_fusion=20,
-    adainIP=True,
-)
-
-MAX_SEED = np.iinfo(np.int32).max
-
-
-def get_example():
-    case = [
-        [
-            './assets/style1.jpg',
-            "A red apple",
-            7.0,
-            42,
-            10,
-        ],
-        [
-            './assets/style2.jpg',
-            "A black car",
-            7.0,
-            42,
-            10,
-        ],
-        [
-            './assets/style3.jpg',
-            "A orange bus",
-            7.0,
-            42,
-            10,
-        ],
-    ]
-    return case
-
-def run_for_examples(style_image_pil, prompt, guidance_scale, seed, end_fusion):
-
-    return create_image(
-        style_image_pil=style_image_pil,
-        prompt=prompt,
-        neg_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
-        guidance_scale=guidance_scale,
-        num_inference_steps=50,
-        seed=seed,
-        end_fusion=end_fusion,
-        use_SAttn=True,
-        crossModalAdaIN=True,
-    )
-
-def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
-    if randomize_seed:
-        seed = random.randint(0, MAX_SEED)
-    return seed
-
-def create_image(style_image_pil,
-                 prompt,
-                 neg_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
-                 guidance_scale=7,
-                 num_inference_steps=50,
-                 end_fusion=20,
-                 crossModalAdaIN=True,
-                 use_SAttn=True,
-                 seed=42,
-                 ):
-
-    style_image = style_image_pil
-
-    generator = torch.Generator(device).manual_seed(seed)
-    init_latents = torch.randn((1, 4, 128, 128), generator=generator, device="cuda", dtype=torch.float16)
-    num_sample = 1
-    if use_SAttn:
-        num_sample = 2
-    init_latents = init_latents.repeat(num_sample, 1, 1, 1)
-    with torch.no_grad():
-        images = csgo.generate(pil_style_image=style_image,
-                               prompt=prompt,
-                               negative_prompt=neg_prompt,
-                               height=1024,
-                               width=1024,
-                               guidance_scale=guidance_scale,
-                               num_images_per_prompt=1,
-                               num_samples=num_sample,
-                               num_inference_steps=num_inference_steps,
-                               end_fusion=end_fusion,
-                               cross_modal_adain=crossModalAdaIN,
-                               use_SAttn=use_SAttn,
-
-                               generator=generator,
-                               latents=init_latents,
-                               )
-
-    if use_SAttn:
-        return [images[1]]
-    else:
-        return [images[0]]
-
-# Description
-title = r"""
-<h1 align="center">StyleStudio: Text-Driven Style Transfer with Selective Control of Style Elements</h1>
-"""
-
-description = r"""
-<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/Westlake-AGI-Lab/StyleStudio' target='_blank'><b>StyleStudio: Text-Driven Style Transfer with Selective Control of Style Elements</b></a>.<br>
-How to use:<br>
-1. Upload a style image.
-2. <b>Enter your desired prompt</b>.
-3. Click the <b>Submit</b> button to begin customization.
-4. Share your stylized photo with your friends and enjoy! 😊
-
-Advanced usage:<br>
-1. Click advanced options.
-2. Choose different guidance and steps.
-3. Set the timing for the Teacher Model's participation.
-4.
-"""
-
-article = r"""
----
-📝 **Tips**
-<br>
-1. As the value of end_fusion <b>increases</b>, the style gradually diminishes.
-Therefore, it is suggested to set end_fusion to be between 1/5 and 1/3 of the number of inference steps (num inference steps).
-2. If you want to experience style-based CFG, see the details on the <a href="https://github.com/Westlake-AGI-Lab/StyleStudio">GitHub repo</a>.
-
----
-📝 **Citation**
-<br>
-If our work is helpful for your research or applications, please cite us via:
-```bibtex
-
-```
-📧 **Contact**
-<br>
-If you have any questions, please feel free to open an issue or directly reach us out at <b>leimingkun@westlake.edu.cn</b>.
-"""
-
-block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False)
-with block:
-    gr.Markdown(title)
-    gr.Markdown(description)
-
-    with gr.Tabs():
-        with gr.Row():
-            with gr.Column():
-                with gr.Row():
-                    with gr.Column():
-                        style_image_pil = gr.Image(label="Style Image", type='pil')
-
-                prompt = gr.Textbox(label="Prompt",
-                                    value="A red apple")
-
-                neg_prompt = gr.Textbox(label="Negative Prompt",
-                                        value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry")
-
-                with gr.Accordion(open=True, label="Advanced Options"):
-
-                    guidance_scale = gr.Slider(minimum=1, maximum=15.0, step=0.01, value=7.0, label="guidance scale")
-
-                    num_inference_steps = gr.Slider(minimum=5, maximum=200.0, step=1.0, value=50,
-                                                    label="num inference steps")
-
-                    end_fusion = gr.Slider(minimum=0, maximum=200, step=1.0, value=20.0, label="end fusion")
-
-                    seed = gr.Slider(minimum=-1000000, maximum=1000000, value=42, step=1, label="Seed Value")
-
-                    randomize_seed = gr.Checkbox(label="Randomize seed", value=False)
-
-                    crossModalAdaIN = gr.Checkbox(label="Cross Modal AdaIN", value=True)
-                    use_SAttn = gr.Checkbox(label="Teacher Model", value=True)
-
-                generate_button = gr.Button("Generate Image")
-
-            with gr.Column():
-                generated_image = gr.Gallery(label="Generated Image")
-
-    generate_button.click(
-        fn=randomize_seed_fn,
-        inputs=[seed, randomize_seed],
-        outputs=seed,
-        queue=False,
-        api_name=False,
-    ).then(
-        fn=create_image,
-        inputs=[
-            style_image_pil,
-            prompt,
-            neg_prompt,
-            guidance_scale,
-            num_inference_steps,
-            end_fusion,
-            crossModalAdaIN,
-            use_SAttn,
-            seed,],
-        outputs=[generated_image])
-
-    gr.Examples(
-        examples=get_example(),
-        inputs=[style_image_pil, prompt, guidance_scale, seed, end_fusion],
-        fn=run_for_examples,
-        outputs=[generated_image],
-        cache_examples=False,
-    )
-
-    gr.Markdown(article)
-
-block.launch(server_name="0.0.0.0", server_port=1234)
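One detail of the deleted demo worth noting: when the "Teacher Model" option (`use_SAttn`) is enabled, `create_image` repeats the initial SDXL latents so the teacher and student samples start from identical noise in one batch, then returns only the second image (`images[1]`). A self-contained sketch of that latent handling (shapes assume SDXL at 1024x1024, i.e. a 4x128x128 latent; the `csgo.generate` call itself is omitted):

```python
# Sketch of the latent handling from the deleted app_exp.py: with
# use_SAttn the same starting noise is repeated so teacher and student
# denoise together in one batch. Runs on CPU here for illustration.
import torch

def build_latents(seed: int, use_SAttn: bool, device: str = "cpu") -> torch.Tensor:
    generator = torch.Generator(device).manual_seed(seed)
    init_latents = torch.randn((1, 4, 128, 128), generator=generator, device=device)
    num_sample = 2 if use_SAttn else 1
    return init_latents.repeat(num_sample, 1, 1, 1)

latents = build_latents(seed=42, use_SAttn=True)
print(latents.shape)  # torch.Size([2, 4, 128, 128]); the demo returned images[1]
```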