File size: 6,786 Bytes
9887d4c 7696de6 9887d4c 39e79b5 7696de6 9887d4c 7696de6 28fa58e af079bb b93d27c cb9c510 7696de6 9d21a93 7696de6 b93d27c 7696de6 af079bb 39e79b5 9887d4c 28fa58e 5e12bb2 9887d4c 7696de6 39e79b5 9d21a93 8c21422 4bacbb7 7696de6 9887d4c 39e79b5 7696de6 5072f90 7696de6 9887d4c 5072f90 9887d4c 15bfa60 35e3bb1 8558373 9887d4c 7696de6 9887d4c fbdf399 2935cbc 9887d4c 6c25594 9887d4c 6c25594 15bfa60 6c25594 cb9c510 9887d4c cb9c510 9887d4c cb9c510 9887d4c cb9c510 7696de6 7e829d4 cb9c510 7e829d4 cb9c510 7e829d4 cb9c510 9887d4c cb9c510 9887d4c cb9c510 9887d4c cb9c510 9887d4c cb9c510 9887d4c cb9c510 9887d4c cb9c510 9887d4c 099c99b 9887d4c cb9c510 9887d4c 099c99b 9887d4c cb9c510 9887d4c 099c99b 9887d4c cb9c510 9887d4c 7696de6 9887d4c 3da8c6a 9887d4c 7e829d4 9887d4c 7696de6 8245d19 7696de6 5ddbee5 9887d4c f8ac431 7696de6 03a8244 7696de6 9887d4c cb9c510 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 |
import gradio as gr
import numpy as np
import random
import torch
from PIL import Image
import os
from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor, pipeline
from kolors.pipelines.pipeline_stable_diffusion_xl_chatglm_256_ipadapter import StableDiffusionXLPipeline
from kolors.models.modeling_chatglm import ChatGLMModel
from kolors.models.tokenization_chatglm import ChatGLMTokenizer
from kolors.models.unet_2d_condition import UNet2DConditionModel
from diffusers import AutoencoderKL, EulerDiscreteScheduler
from huggingface_hub import snapshot_download
import spaces
# ---------------------------------------------------------------------------
# Paths and target device
# ---------------------------------------------------------------------------
device = "cuda"
root_dir = os.getcwd()
ckpt_dir = f"{root_dir}/weights/Kolors"
ip_adapter_dir = f"{root_dir}/weights/Kolors-IP-Adapter-Plus"

# Download the base Kolors checkpoint and the IP-Adapter-Plus weights from the
# Hugging Face Hub into ./weights (runs at import time).
snapshot_download(repo_id="Kwai-Kolors/Kolors", local_dir=ckpt_dir)
snapshot_download(
    repo_id="Kwai-Kolors/Kolors-IP-Adapter-Plus",
    local_dir=ip_adapter_dir,
)

# ---------------------------------------------------------------------------
# Load models
# ---------------------------------------------------------------------------
# The text encoder and its tokenizer are read from the same sub-directory.
text_encoder_dir = f"{ckpt_dir}/text_encoder"
text_encoder = ChatGLMModel.from_pretrained(
    text_encoder_dir,
    torch_dtype=torch.float16,
)
text_encoder = text_encoder.half().to(device)
tokenizer = ChatGLMTokenizer.from_pretrained(text_encoder_dir)

vae = AutoencoderKL.from_pretrained(f"{ckpt_dir}/vae", revision=None)
vae = vae.half().to(device)

scheduler = EulerDiscreteScheduler.from_pretrained(f"{ckpt_dir}/scheduler")

unet = UNet2DConditionModel.from_pretrained(f"{ckpt_dir}/unet", revision=None)
unet = unet.half().to(device)

# CLIP vision tower shipped with the IP-Adapter weights, moved to fp16 on the
# target device in one step.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    f"{ip_adapter_dir}/image_encoder",
    ignore_mismatched_sizes=True,
)
image_encoder = image_encoder.to(dtype=torch.float16, device=device)

# The reference image is resized and cropped to this square edge length.
ip_img_size = 336
clip_image_processor = CLIPImageProcessor(
    size=ip_img_size,
    crop_size=ip_img_size,
)

# ---------------------------------------------------------------------------
# Assemble the pipeline
# ---------------------------------------------------------------------------
pipe = StableDiffusionXLPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    scheduler=scheduler,
    image_encoder=image_encoder,
    feature_extractor=clip_image_processor,
    force_zeros_for_empty_prompt=False,
)
pipe = pipe.to(device)

# Keep a second handle on the UNet's existing projection under the name
# text_encoder_hid_proj before the IP-Adapter is loaded.
# NOTE(review): presumably load_ip_adapter replaces encoder_hid_proj - confirm.
if hasattr(pipe.unet, "encoder_hid_proj"):
    pipe.unet.text_encoder_hid_proj = pipe.unet.encoder_hid_proj

pipe.load_ip_adapter(
    ip_adapter_dir,
    subfolder="",
    weight_name=["ip_adapter_plus_general.bin"],
)

# Korean -> English translation model used on prompts before generation.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")

# Limits shared by the seed and image-size sliders in the UI.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def _contains_hangul(text):
    """Return True if *text* contains at least one Hangul (Korean) character."""
    return any(
        "\uac00" <= ch <= "\ud7a3"       # precomposed Hangul syllables
        or "\u1100" <= ch <= "\u11ff"    # Hangul Jamo
        or "\u3130" <= ch <= "\u318f"    # Hangul compatibility Jamo
        for ch in text
    )


@spaces.GPU
def infer(
    prompt,
    ip_adapter_image,
    ip_adapter_scale=0.5,
    negative_prompt="",
    seed=100,
    randomize_seed=False,
    width=1024,
    height=1024,
    guidance_scale=5.0,
    num_inference_steps=50,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image from a text prompt and an IP-Adapter reference image.

    Args:
        prompt: Text prompt. It is run through the Korean->English translator
            only when it contains Hangul; otherwise it is used unchanged.
        ip_adapter_image: Reference image handed to the pipeline's IP-Adapter.
        ip_adapter_scale: Value passed to ``pipe.set_ip_adapter_scale``.
        negative_prompt: Negative prompt forwarded to the pipeline as-is.
        seed: Seed for the torch generator; ignored if ``randomize_seed``.
        randomize_seed: When true, draw a fresh seed in ``[0, MAX_SEED]``.
        width: Output width forwarded to the pipeline.
        height: Output height forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of denoising steps forwarded to the pipeline.
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        A ``(image, seed)`` tuple: the first image produced by the pipeline and
        the seed that was actually used.

    Raises:
        gr.Error: If no reference image was supplied.
    """
    # Fail early with a readable message instead of an error from deep inside
    # the pipeline when the image input is empty.
    if ip_adapter_image is None:
        raise gr.Error("Please provide an IP-Adapter reference image.")

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # UI sliders may hand numbers over as floats; manual_seed needs an int.
    seed = int(seed)

    # Translate only prompts that actually contain Korean. Empty or
    # non-Korean prompts are passed through untouched so the ko->en model
    # cannot distort them.
    prompt = prompt or ""
    if _contains_hangul(prompt):
        translated_prompt = translator(prompt, src_lang="ko", tgt_lang="en")[0]["translation_text"]
    else:
        translated_prompt = prompt

    generator = torch.Generator(device=device).manual_seed(seed)

    # Re-assert device placement and the encoder binding on every call.
    # NOTE(review): presumably needed because the GPU is only attached inside
    # the @spaces.GPU-decorated call - confirm.
    pipe.to(device)
    image_encoder.to(device)
    pipe.image_encoder = image_encoder
    pipe.set_ip_adapter_scale([ip_adapter_scale])

    image = pipe(
        prompt=translated_prompt,
        ip_adapter_image=[ip_adapter_image],
        negative_prompt=negative_prompt,
        height=int(height),
        width=int(width),
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
        num_images_per_prompt=1,
        generator=generator,
    ).images[0]

    return image, seed
# Rows for gr.Examples, in the order [prompt, reference image file, IP-Adapter scale].
# NOTE(review): the Korean prompts were reconstructed from mis-decoded UTF-8
# (shown as Thai-looking mojibake); some bytes were lost, so confirm the exact
# wording - in particular the third prompt - against the upstream file.
examples = [
    ["춤을 추어라", "woman.png", 0.4],
    ["강아지", "minta.jpeg", 0.4],
    ["환하게 웃어라", "trump.png", 0.5],
    ["올빼미", "forest.png", 0.5],
    ["", "meow.jpeg", 1.0],
]
# Stylesheet handed to gr.Blocks: centres the main column, caps its width, and
# adjusts how the result image is positioned inside its container.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 720px;\n"
    "}\n"
    "#result img{\n"
    "object-position: top;\n"
    "}\n"
    "#result .image-container{\n"
    "height: 100%\n"
    "}\n"
)
# Gradio UI: prompt box + run button, reference image with its influence
# slider next to the result, an "advanced settings" accordion, and examples.
# NOTE(review): the Korean labels were reconstructed from mis-decoded UTF-8 and
# the nesting was rebuilt because the original indentation was lost; confirm
# both against the upstream file.
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""
# 한글로 이미지 변형 및 생성 서비스 @ https://discord.gg/openfreeai
""")

        with gr.Row():
            prompt = gr.Text(
                label="프롬프트",
                show_label=False,
                max_lines=1,
                placeholder="프롬프트를 입력하세요",
                container=False,
            )
            run_button = gr.Button("실행", scale=0)

        with gr.Row():
            with gr.Column():
                ip_adapter_image = gr.Image(label="IP-어댑터 이미지", type="pil")
                ip_adapter_scale = gr.Slider(
                    label="이미지 영향 척도",
                    info="변형을 생성하려면 1을 사용하세요",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.05,
                    value=0.5,
                )
            result = gr.Image(label="결과", elem_id="result")

        with gr.Accordion("고급 설정", open=False):
            negative_prompt = gr.Text(
                label="부정적 프롬프트",
                max_lines=1,
                placeholder="부정적 프롬프트를 입력하세요",
            )
            seed = gr.Slider(
                label="시드",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )
            randomize_seed = gr.Checkbox(label="시드 무작위화", value=True)

            with gr.Row():
                width = gr.Slider(
                    label="너비",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )
                height = gr.Slider(
                    label="높이",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="가이던스 척도",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=5.0,
                )
                num_inference_steps = gr.Slider(
                    label="추론 단계 수",
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=100,
                )

        # Examples supply only the first three inputs; infer's keyword
        # defaults cover the rest.
        gr.Examples(
            examples=examples,
            fn=infer,
            inputs=[prompt, ip_adapter_image, ip_adapter_scale],
            outputs=[result, seed],
            cache_examples="lazy",
        )

    # Clicking the button and submitting the prompt box both run infer with
    # the full set of controls; the seed slider is updated with the seed used.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            ip_adapter_image,
            ip_adapter_scale,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, seed],
    )
# Launch the app. share=True asks Gradio to also create a public share link.
demo.launch(share=True)
|