File size: 13,547 Bytes
06f0d78 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 |
import sys, os
import gradio as gr
## if kgen not exist
try:
    import kgen
except ImportError:
    # kgen is missing: install it from GitHub at import time. When
    # GITHUB_TOKEN is set it is embedded in the URL for authentication
    # (presumably needed because the repository is private).
    import importlib
    import subprocess

    GH_TOKEN = os.getenv("GITHUB_TOKEN")
    # Without a token, do not produce "https://None@github.com/...".
    _auth = f"{GH_TOKEN}@" if GH_TOKEN else ""
    git_url = f"https://{_auth}github.com/KohakuBlueleaf/TIPO-KGen@tipo"
    # Run pip for *this* interpreter, passing an argument list (no shell).
    _completed = subprocess.run(
        [sys.executable, "-m", "pip", "install", f"git+{git_url}"]
    )
    if _completed.returncode != 0:
        # Do not include the command in the message: it contains the token.
        raise RuntimeError(
            f"Installing TIPO-KGen failed (pip exit code {_completed.returncode})"
        )
    # The import system caches directory listings; refresh them so the
    # package installed above is found by the `import kgen.*` lines below.
    importlib.invalidate_caches()
import re
import random
from time import time
import torch
from transformers import set_seed
if sys.platform == "win32":
    # Windows is only used as a local dev environment, where @spaces.GPU
    # causes problems, so substitute a no-op decorator.
    def GPU(func=None, **kwargs):
        """No-op stand-in for ``spaces.GPU``.

        Supports both the bare ``@GPU`` form and the ``@GPU(duration=...)``
        form; keyword arguments are accepted and ignored.

        Args:
            func: the decorated function when used as a bare ``@GPU``.
            **kwargs: ignored (e.g. ``duration``).

        Returns:
            ``func`` unchanged when given, otherwise an identity decorator.
        """
        if callable(func):
            return func
        return lambda x: x

else:
    from spaces import GPU
import kgen.models as models
import kgen.executor.tipo as tipo
from kgen.formatter import seperate_tags, apply_format
from kgen.generate import generate
from diff import load_model, encode_prompts
from meta import DEFAULT_NEGATIVE_PROMPT, DEFAULT_FORMAT
# Load the SDXL pipeline and park each sub-model on the CPU; generate_image
# moves them onto CUDA only while it renders and moves them back afterwards.
sdxl_pipe = load_model()
for _component in (
    sdxl_pipe.text_encoder,
    sdxl_pipe.text_encoder_2,
    sdxl_pipe.vae,
    sdxl_pipe.k_diffusion_model,
):
    _component.to("cpu")
del _component

# Load the TIPO model onto CUDA and run a 4-token generation whose output is
# discarded (presumably a warm-up). NOTE: at this point `generate` is still
# kgen.generate.generate; the Gradio handler defined later in this module
# reuses (shadows) that name.
models.load_model("Amber-River/tipo", device="cuda", subfolder="500M-epoch3")
generate(max_new_tokens=4)
torch.cuda.empty_cache()
# Example Danbooru tag prompt pre-filled into the "Danbooru Tags" box.
DEFAULT_TAGS = (
    "1girl, king halo (umamusume), umamusume,\n"
    "ningen mame, ciloranko, ogipote, misu kasumi,\n"
    "solo, leaning forward, sky,\n"
    "masterpiece, absurdres, sensitive, newest"
)
# Example natural-language prompt pre-filled into the "Natural Language Prompt" box.
DEFAULT_NL = "An illustration of a girl"
def _format_rate(count, seconds):
    """Format ``count / seconds`` as a 7-wide, 2-decimal number.

    Returns a right-aligned ``N/A`` of the same width when *seconds* is not
    positive, instead of raising ZeroDivisionError.
    """
    if seconds > 0:
        return f"{count / seconds:7.2f}"
    return f"{'N/A':>7}"


def format_time(timing):
    """Render TIPO timing statistics as a Markdown snippet.

    Args:
        timing: stats dict. ``"total"`` (seconds) and ``"generate_pass"``
            (count) are always read. If ``"generated_tokens"`` is present,
            ``"input_tokens"`` is read as well; if ``"total_sampling"`` is
            also present, ``"total_sampling"``, ``"prompt_process"`` and
            ``"total_eval"`` are read and divided by 1000 (presumably
            milliseconds -> seconds).

    Returns:
        Markdown: a "Process Time" table and, when token counts are
        available, a "Processed Tokens" list. A rate whose duration is zero
        is shown as ``N/A``.
    """
    total = timing["total"]
    generate_pass = timing["generate_pass"]
    parts = [
        "\n### Process Time\n"
        f"| Total | {total:5.2f} sec / {generate_pass:5} Passes | "
        f"{_format_rate(generate_pass, total)} Passes Per Second|\n"
        "|-|-|-|\n"
    ]

    if "generated_tokens" in timing:
        total_generated_tokens = timing["generated_tokens"]
        total_input_tokens = timing["input_tokens"]

        if "total_sampling" in timing:
            rows = (
                ("Process", timing["prompt_process"] / 1000, total_input_tokens),
                ("Sampling", timing["total_sampling"] / 1000, total_generated_tokens),
                ("Eval", timing["total_eval"] / 1000, total_generated_tokens),
            )
            for label, seconds, tokens in rows:
                parts.append(
                    f"| {label} | {seconds:5.2f} sec / {tokens:5} Tokens | "
                    f"{_format_rate(tokens, seconds)} Tokens Per Second|\n"
                )

        parts.append(
            "\n### Processed Tokens:\n"
            f"* {total_input_tokens} Input Tokens\n"
            f"* {total_generated_tokens} Output Tokens\n"
        )

    return "".join(parts)
@GPU(duration=10)
@torch.no_grad()
def generate(
    tags,
    nl_prompt,
    black_list,
    temp,
    output_format,
    target_length,
    top_p,
    min_p,
    top_k,
    seed,
    escape_brackets,
):
    """Refine a prompt with TIPO, streaming every intermediate result.

    Args:
        tags: comma-separated Danbooru tags.
        nl_prompt: optional natural-language prompt (falsy for none).
        black_list: comma-separated entries written to ``tipo.BAN_TAGS``.
        temp, top_p, min_p, top_k, seed: sampling settings forwarded to
            ``tipo.tipo_runner_generator`` (``temp`` as ``temperature``).
        output_format: key into ``DEFAULT_FORMAT`` selecting the template.
        target_length: forwarded as ``tag_length_target``.
        escape_brackets: if truthy, backslash-escape ``( ) [ ]`` in both the
            result and the formatted input.

    Yields:
        ``(formatted_result, formatted_input_prompt, timing_markdown)``.
    """

    def _maybe_escape(text):
        if not escape_brackets:
            return text
        return re.sub(r"([()\[\]])", r"\\\1", text)

    torch.cuda.empty_cache()
    template = DEFAULT_FORMAT[output_format]
    # NOTE: this rebinds module-level state on kgen's tipo executor module.
    tipo.BAN_TAGS = [
        entry.strip() for entry in black_list.split(",") if entry.strip()
    ]
    sampling = dict(
        seed=seed,
        temperature=temp,
        top_p=top_p,
        min_p=min_p,
        top_k=top_k,
    )

    # The user's own input, rendered with the same template; it is the
    # second value of every yielded tuple.
    user_parts = seperate_tags(tags.split(","))
    if nl_prompt:
        if "<|extended|>" in template:
            user_parts["extended"] = nl_prompt
        elif "<|generated|>" in template:
            user_parts["generated"] = nl_prompt
    formatted_input = _maybe_escape(apply_format(user_parts, template))

    # A fresh tag split is passed here (not user_parts, which may carry the
    # NL prompt under "extended"/"generated").
    meta, operations, general, nl_prompt = tipo.parse_tipo_request(
        seperate_tags(tags.split(",")),
        nl_prompt,
        tag_length_target=target_length,
        generate_extra_nl_prompt="<|generated|>" in template or not nl_prompt,
    )

    started = time()
    stream = tipo.tipo_runner_generator(
        meta, operations, general, nl_prompt, **sampling
    )
    for partial, timing in stream:
        text = _maybe_escape(apply_format(partial, template))
        timing["total"] = time() - started
        yield text, formatted_input, format_time(timing)
    torch.cuda.empty_cache()
@torch.no_grad()
def _render_sdxl(seed, prompt):
    """Render one 1024x1024 image for *prompt* with the shared SDXL pipeline.

    The text encoders are moved to CUDA before encoding; the VAE and the
    diffusion model are moved only after the prompt is encoded. Every
    sub-model is moved back to the CPU (and the CUDA cache emptied) in a
    ``finally`` block, so a failed render does not leave weights on the GPU.

    Args:
        seed: passed to ``transformers.set_seed`` before encoding.
        prompt: positive prompt text; DEFAULT_NEGATIVE_PROMPT is the negative.

    Returns:
        The first image of the pipeline output.
    """
    set_seed(seed)
    try:
        sdxl_pipe.text_encoder.to("cuda")
        sdxl_pipe.text_encoder_2.to("cuda")
        prompt_embeds, negative_prompt_embeds, pooled_embeds, neg_pooled_embeds = (
            encode_prompts(sdxl_pipe, prompt, DEFAULT_NEGATIVE_PROMPT)
        )
        sdxl_pipe.vae.to("cuda")
        sdxl_pipe.k_diffusion_model.to("cuda")
        return sdxl_pipe(
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_embeds,
            negative_pooled_prompt_embeds=neg_pooled_embeds,
            num_inference_steps=24,
            width=1024,
            height=1024,
            guidance_scale=6.0,
        ).images[0]
    finally:
        sdxl_pipe.text_encoder.to("cpu")
        sdxl_pipe.text_encoder_2.to("cpu")
        sdxl_pipe.vae.to("cpu")
        sdxl_pipe.k_diffusion_model.to("cpu")
        torch.cuda.empty_cache()


@GPU(duration=20)
@torch.no_grad()
def generate_image(
    seed,
    prompt,
    prompt2,
):
    """Render an image for *prompt2* and then one for *prompt*, same seed.

    In this app the UI passes the TIPO result as *prompt* and the formatted
    user input as *prompt2*.

    Args:
        seed: seed applied before each render.
        prompt: rendered second.
        prompt2: rendered first.

    Yields:
        ``(image_for_prompt2, None)`` once the first image is ready, then
        ``(image_for_prompt2, image_for_prompt)``.
    """
    torch.cuda.empty_cache()
    first_image = _render_sdxl(seed, prompt2)
    yield first_image, None
    second_image = _render_sdxl(seed, prompt)
    yield first_image, second_image
if __name__ == "__main__":
    with gr.Blocks(theme=gr.themes.Soft()) as demo:
        with gr.Accordion("Introduction and Instructions", open=False):
            gr.Markdown(
                """
## TIPO Demo
### What is this
TIPO is a tool to extend, generate, refine the input prompt for T2I models.
<br>It can work on both Danbooru tags and Natural Language, which means you can use it on almost all the existing T2I models.
<br>You can take it as "pro max" version of [DTG](https://huggingface.co/KBlueLeaf/DanTagGen-delta-rev2)
### How to use this demo
1. Enter your tags (optional): put the desired tags into "Danbooru Tags" box
2. Enter your NL Prompt (optional): put the desired natural language prompt into "Natural Language Prompt" box
3. Enter your black list (optional): put the desired black list into "Black List" box
4. Adjust the settings: length, temp, top_p, min_p, top_k, seed ...
5. Click "TIPO!" button: you will see refined prompt on "Result" box
6. If you like the result, click "Generate Image from Result" button
* You will see 2 generated images, left one is based on your prompt, right one is based on refined prompt
* The backend is diffusers, there is no weighting mechanism, so Escape Brackets defaults to False
### Why inference code is private? When will it be open sourced?
1. This model/tool is still under development, currently is early Alpha version.
2. I'm doing some research and projects based on this.
3. The model is released under CC-BY-NC-ND License currently. If you have interest, you can implement inference by yourself.
4. Once the project/research are done, I will open source all these models/codes with Apache2 license.
### Notification
**TIPO is NOT a T2I model. It is Prompt Gen, or, Text-to-Text model.
<br>The generated image comes from [Kohaku-XL-Zeta](https://huggingface.co/KBlueLeaf/Kohaku-XL-Zeta) model**
"""
            )
        with gr.Row():
            with gr.Column(scale=5):
                with gr.Row():
                    with gr.Column(scale=3):
                        tags_input = gr.TextArea(
                            label="Danbooru Tags",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_TAGS,
                            placeholder="Enter danbooru tags here",
                        )
                        nl_prompt_input = gr.Textbox(
                            label="Natural Language Prompt",
                            lines=7,
                            show_copy_button=True,
                            interactive=True,
                            value=DEFAULT_NL,
                            placeholder="Enter Natural Language Prompt here",
                        )
                        black_list = gr.TextArea(
                            label="Black List (separated by comma)",
                            lines=4,
                            interactive=True,
                            value="monochrome",
                            placeholder="Enter tag/nl black list here",
                        )
                    with gr.Column(scale=2):
                        output_format = gr.Dropdown(
                            label="Output Format",
                            choices=list(DEFAULT_FORMAT.keys()),
                            value="Both, tag first (recommend)",
                        )
                        target_length = gr.Dropdown(
                            label="Target Length",
                            choices=["very_short", "short", "long", "very_long"],
                            value="long",
                        )
                        temp = gr.Slider(
                            label="Temp",
                            minimum=0.0,
                            maximum=1.5,
                            value=0.5,
                            step=0.05,
                        )
                        top_p = gr.Slider(
                            label="Top P",
                            minimum=0.0,
                            maximum=1.0,
                            value=0.95,
                            step=0.05,
                        )
                        min_p = gr.Slider(
                            label="Min P",
                            minimum=0.0,
                            maximum=0.2,
                            value=0.05,
                            step=0.01,
                        )
                        top_k = gr.Slider(
                            label="Top K", minimum=0, maximum=120, value=60, step=1
                        )
                with gr.Row():
                    # precision=0 makes Gradio hand the handlers an int seed.
                    seed = gr.Number(
                        label="Seed",
                        minimum=0,
                        maximum=2147483647,
                        value=20090220,
                        step=1,
                        precision=0,
                    )
                    escape_brackets = gr.Checkbox(
                        label="Escape Brackets", value=False
                    )
                submit = gr.Button("TIPO!", variant="primary")
                with gr.Accordion("Speed statistics", open=False):
                    cost_time = gr.Markdown()
            with gr.Column(scale=5):
                result = gr.TextArea(
                    label="Result", lines=8, show_copy_button=True, interactive=False
                )
                input_prompt = gr.Textbox(
                    label="Input Prompt", lines=1, interactive=False, visible=False
                )
                gen_img = gr.Button(
                    "Generate Image from Result", variant="primary", interactive=False
                )
                with gr.Row():
                    with gr.Column():
                        img1 = gr.Image(label="Original Prompt", interactive=False)
                    with gr.Column():
                        img2 = gr.Image(label="Generated Prompt", interactive=False)

        def generate_wrapper(*args):
            """Stream TIPO output while the image button stays locked.

            Clears the outputs first. Unlocks "Generate Image from Result"
            only after at least one result has been produced.
            """
            yield "", "", "", gr.update(interactive=False)
            outputs = None
            for outputs in generate(*args):
                yield *outputs, gr.update(interactive=False)
            if outputs is not None:
                yield *outputs, gr.update(interactive=True)

        submit.click(
            generate_wrapper,
            [
                tags_input,
                nl_prompt_input,
                black_list,
                temp,
                output_format,
                target_length,
                top_p,
                min_p,
                top_k,
                seed,
                escape_brackets,
            ],
            [
                result,
                input_prompt,
                cost_time,
                gen_img,
            ],
            queue=True,
        )

        def generate_image_wrapper(seed, result, input_prompt):
            """Stream both images while "TIPO!" is locked, then unlock it.

            The second gen_img.click handler below disables ``submit``
            immediately. If rendering raises, ``submit`` is re-enabled
            (images left unchanged) before the error is re-raised, so the
            button does not stay disabled forever.
            """
            original_img, refined_img = None, None
            try:
                for original_img, refined_img in generate_image(
                    seed, result, input_prompt
                ):
                    yield original_img, refined_img, gr.update(interactive=False)
            except Exception:
                yield gr.update(), gr.update(), gr.update(interactive=True)
                raise
            yield original_img, refined_img, gr.update(interactive=True)

        gen_img.click(
            generate_image_wrapper,
            [seed, result, input_prompt],
            [img1, img2, submit],
            queue=True,
        )
        # Unqueued, so "TIPO!" is disabled right away rather than when the
        # queued image job starts.
        gen_img.click(
            lambda *args: gr.update(interactive=False),
            None,
            [submit],
            queue=False,
        )
    demo.launch()
|