# StyleStudio / app.py
# Author: Leimingkun -- commit 6fe0b16 ("stylestudio"), 8.1 kB
import sys
sys.path.append("./")
import gradio as gr
import spaces
import torch
from ip_adapter.utils import BLOCKS as BLOCKS
import numpy as np
import random
from diffusers import (
AutoencoderKL,
StableDiffusionXLPipeline,
)
from ip_adapter import StyleStudio_Adapter
# Use the GPU when one is visible; pick float16 on CUDA and float32 otherwise.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
import os
from huggingface_hub import hf_hub_download, snapshot_download

# Fetch only the `sdxl_models/` folder of h94/IP-Adapter into ./sdxl_models
# (image_encoder_path points into it). This avoids git-cloning the whole
# repository, including the SD1.5 `models/` weights that were deleted right
# afterwards. Unlike the previous os.system shell calls, download failures
# now raise here instead of being silently ignored.
snapshot_download(
    repo_id="h94/IP-Adapter",
    allow_patterns=["sdxl_models/*"],
    local_dir="./",
)
# StyleStudio/CSGO adapter checkpoint -> ./CSGO/csgo_4_32.bin
hf_hub_download(repo_id="InstantX/CSGO", filename="csgo_4_32.bin", local_dir="./CSGO/")
# Model locations (hub ids or local paths populated by the downloads above).
base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
image_encoder_path = "sdxl_models/image_encoder"
csgo_ckpt = './CSGO/csgo_4_32.bin'
pretrained_vae_name_or_path = 'madebyollin/sdxl-vae-fp16-fix'
weight_dtype = torch.float16

# Load the VAE and the SDXL base pipeline, both in weight_dtype, with the
# invisible watermarker disabled.
vae = AutoencoderKL.from_pretrained(pretrained_vae_name_or_path, torch_dtype=weight_dtype)
pipe = StableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    vae=vae,
    torch_dtype=weight_dtype,
    add_watermarker=False,
)
pipe.enable_vae_tiling()

# Wrap the pipeline in the StyleStudio adapter, targeting the blocks listed
# under BLOCKS['style'].
target_style_blocks = BLOCKS['style']
csgo = StyleStudio_Adapter(
    pipe,
    image_encoder_path,
    csgo_ckpt,
    device,
    num_style_tokens=32,
    target_style_blocks=target_style_blocks,
    controlnet_adapter=False,
    style_model_resampler=True,
    fuSAttn=True,
    end_fusion=20,
    adainIP=True,
)

# Upper bound for randomly drawn seeds: the largest int32 value.
MAX_SEED = np.iinfo(np.int32).max
def get_example():
    """Build the rows for the ``gr.Examples`` table.

    Each row lists values in the order of the Examples ``inputs``:
    style image path, task, prompt, guidance scale, seed, end_fusion.

    Returns:
        A new list of example rows on every call.
    """
    apple_row = [
        './assets/style1.jpg',
        "Text-Driven Style Synthesis",
        "A red apple",
        7.0,
        42,
        20,
    ]
    return [apple_row]
def run_for_examples(style_image_pil, target, prompt, guidance_scale, seed, end_fusion):
    """Run ``create_image`` for one row of the examples table.

    Parameters follow the ``gr.Examples`` inputs order. ``target`` (the task
    radio) is accepted but unused, because there is only one task. The step
    count is fixed at 50, and both the Teacher Model (``use_SAttn``) and
    cross-modal AdaIN are enabled. The row's own ``guidance_scale`` and
    ``seed`` are forwarded; previously they were overridden with 7.0 and 42.

    Returns:
        Whatever ``create_image`` returns (a one-element image list).
    """
    return create_image(
        style_image_pil=style_image_pil,
        prompt=prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=50,
        seed=seed,
        end_fusion=end_fusion,
        use_SAttn=True,
        crossModalAdaIN=True,
    )
def randomize_seed_fn(seed: int, randomize_seed: bool) -> int:
    """Return the seed to use for the next generation.

    Args:
        seed: Seed currently shown in the UI.
        randomize_seed: When true, ignore ``seed`` and draw a new one.

    Returns:
        ``seed`` unchanged, or a uniform random integer in ``[0, MAX_SEED]``.
    """
    if not randomize_seed:
        return seed
    return random.randint(0, MAX_SEED)
@spaces.GPU
def create_image(
    style_image_pil,
    prompt,
    guidance_scale,
    num_inference_steps,
    end_fusion,
    crossModalAdaIN,
    use_SAttn,
    seed,
    neg_prompt="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
):
    """Generate one stylized image from a style reference and a text prompt.

    Args:
        style_image_pil: Style reference image (PIL, from ``gr.Image(type='pil')``).
        prompt: Text prompt for the content to synthesize.
        guidance_scale: Classifier-free guidance scale passed to ``csgo.generate``.
        num_inference_steps: Number of denoising steps.
        end_fusion: Forwarded unchanged to ``csgo.generate(end_fusion=...)``.
        crossModalAdaIN: Forwarded as ``cross_modal_adain``.
        use_SAttn: "Teacher Model" toggle. When true, two samples are generated
            and the second one is returned.
        seed: Seed for the torch generator.
        neg_prompt: Negative prompt.

    Returns:
        A one-element list with the generated image, for ``gr.Gallery``.

    Raises:
        gr.Error: If no style image was provided.
    """
    if style_image_pil is None:
        raise gr.Error("Please upload a style image.")

    # Slider values may arrive as floats; manual_seed and the scheduler's
    # step count require ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)

    generator = torch.Generator(device).manual_seed(seed)
    # NOTE(review): the tensor drawn here is never used; the draw only advances
    # `generator` before generation. It is kept so a given seed keeps giving
    # the same image as before. It now uses `device`/`dtype` instead of a
    # hard-coded "cuda", which made torch reject the CPU generator when no GPU
    # was present.
    torch.randn((1, 4, 128, 128), generator=generator, device=device, dtype=dtype)

    num_sample = 2 if use_SAttn else 1
    with torch.no_grad():
        images = csgo.generate(
            pil_style_image=style_image_pil,
            prompt=prompt,
            negative_prompt=neg_prompt,
            height=1024,
            width=1024,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            num_samples=num_sample,
            num_inference_steps=num_inference_steps,
            end_fusion=end_fusion,
            cross_modal_adain=crossModalAdaIN,
            use_SAttn=use_SAttn,
            generator=generator,
        )
    # images[1] with the Teacher Model enabled, images[0] otherwise.
    return [images[num_sample - 1]]
# Description texts rendered by gr.Markdown in the UI below.
title = r"""
<h1 align="center">StyleStudio: Text-Driven Style Transfer with Selective Control of Style Elements</h1>
"""

description = r"""
<b>Official 🤗 Gradio demo</b> for <a href='https://github.com/MingKunLei/StyleStudio' target='_blank'><b>StyleStudio: Text-Driven Style Transfer with Selective Control of Style Elements</b></a>.<br>

How to use:<br>
1. Upload a style image.
2. <b>Enter your desired prompt</b>.
3. Click the <b>Generate Image</b> button to begin customization.
4. Share your stylized photo with your friends and enjoy! 😊

Advanced usage:<br>
1. Click advanced options.
2. Choose different guidance and steps.
3. Set the timing for the Teacher Model's participation.
"""

# TODO: the ```bibtex block below is empty; fill in the citation once available.
# Blank lines around `---` keep Markdown from turning the preceding paragraph
# into a setext heading.
article = r"""
---

📝 **Tips**

As the value of end_fusion increases, the style gradually diminishes.

---

📝 **Citation**
<br>
If our work is helpful for your research or applications, please cite us via:
```bibtex
```

📧 **Contact**
<br>
If you have any questions, please feel free to open an issue or reach out to us directly at <b>leimingkun@westlake.edu.cn</b>.
"""
# Gradio UI: inputs and options on the left, output gallery on the right,
# the example table and the article below.
block = gr.Blocks(css="footer {visibility: hidden}").queue(max_size=10, api_open=False)
with block:
    gr.Markdown(title)
    gr.Markdown(description)
    with gr.Tabs():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        style_image_pil = gr.Image(label="Style Image", type='pil')
                target = gr.Radio(
                    ["Text-Driven Style Synthesis"],
                    value="Text-Driven Style Synthesis",
                    label="task",
                )
                prompt = gr.Textbox(label="Prompt", value="A red apple")
                neg_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="text, watermark, lowres, low quality, worst quality, deformed, glitch, low contrast, noisy, saturation, blurry",
                )
                with gr.Accordion(open=True, label="Advanced Options"):
                    guidance_scale = gr.Slider(minimum=1, maximum=15.0, step=0.01, value=7.0, label="guidance scale")
                    num_inference_steps = gr.Slider(
                        minimum=5, maximum=100.0, step=1.0, value=50, label="num inference steps"
                    )
                    # `maximum` must be a number. Passing the num_inference_steps
                    # component does not link the two sliders, so mirror that
                    # slider's upper bound instead.
                    end_fusion = gr.Slider(minimum=0, maximum=100.0, step=1.0, value=20.0, label="end fusion")
                    # Same range that randomize_seed_fn draws from ([0, MAX_SEED]),
                    # so a randomized seed always fits the slider.
                    seed = gr.Slider(minimum=0, maximum=MAX_SEED, value=1, step=1, label="Seed Value")
                    randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
                    crossModalAdaIN = gr.Checkbox(label="Cross Modal AdaIN", value=True)
                    use_SAttn = gr.Checkbox(label="Teacher Model", value=True)
                generate_button = gr.Button("Generate Image")
            with gr.Column():
                generated_image = gr.Gallery(label="Generated Image")

        # Resolve the seed first (optionally randomizing it), then generate.
        generate_button.click(
            fn=randomize_seed_fn,
            inputs=[seed, randomize_seed],
            outputs=seed,
            queue=False,
            api_name=False,
        ).then(
            fn=create_image,
            inputs=[
                style_image_pil,
                prompt,
                guidance_scale,
                num_inference_steps,
                end_fusion,
                crossModalAdaIN,
                use_SAttn,
                seed,
                neg_prompt,
            ],
            outputs=[generated_image],
        )

    gr.Examples(
        examples=get_example(),
        inputs=[style_image_pil, target, prompt, guidance_scale, seed, end_fusion],
        fn=run_for_examples,
        outputs=[generated_image],
        cache_examples=False,
    )

    gr.Markdown(article)

block.launch()