|
import os |
|
import yaml |
|
import torch |
|
import sys |
|
sys.path.append(os.path.abspath('./')) |
|
from inference.utils import * |
|
from train import WurstCoreB |
|
from gdf import DDPMSampler |
|
from train import WurstCore_t2i as WurstCoreC |
|
import numpy as np |
|
import random |
|
import argparse |
|
import gradio as gr |
|
import spaces |
|
from huggingface_hub import hf_hub_url |
|
import subprocess |
|
from huggingface_hub import hf_hub_download |
|
from transformers import pipeline |
|
|
|
|
|
# Korean -> English translation pipeline, built once at import time and used to
# turn Korean prompts into English before they reach the text encoder.
translator = pipeline(task="translation", model="Helsinki-NLP/opus-mt-ko-en")
|
|
|
def parse_args():
    """Parse the command-line options for the UltraPixel demo.

    Returns:
        argparse.Namespace with the output size, seed, dtype, stage C / stage B
        config paths, default prompt, image count, output directory, stage A
        tiling flag and the path of the UltraPixel checkpoint.
    """
    # (flag, add_argument keyword arguments), registered in this exact order.
    option_specs = [
        ('--height', dict(type=int, default=2560, help='image height')),
        ('--width', dict(type=int, default=5120, help='image width')),
        ('--seed', dict(type=int, default=123, help='random seed')),
        ('--dtype', dict(type=str, default='bf16',
                         help='if bf16 does not work, change it to float32')),
        ('--config_c', dict(type=str, default='configs/training/t2i.yaml',
                            help='config file for stage c, latent generation')),
        ('--config_b', dict(type=str, default='configs/inference/stage_b_1b.yaml',
                            help='config file for stage b, latent decoding')),
        ('--prompt', dict(type=str,
                          default='A photo-realistic image of a west highland white terrier in the garden, high quality, detail rich, 8K',
                          help='text prompt')),
        ('--num_image', dict(type=int, default=1, help='how many images generated')),
        ('--output_dir', dict(type=str, default='figures/output_results/',
                              help='output directory for generated image')),
        ('--stage_a_tiled', dict(action='store_true',
                                 help='whether or not to use tiled decoding for stage a to save memory')),
        ('--pretrained_path', dict(type=str, default='models/ultrapixel_t2i.safetensors',
                                   help='pretrained path of newly added parameter of UltraPixel')),
    ]

    cli = argparse.ArgumentParser()
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
|
def clear_image():
    """Return ``None`` so the bound Gradio output component is emptied."""
|
|
|
def load_message(height, width, seed, prompt, args, stage_a_tiled):
    """Copy per-request generation settings onto *args* and return it.

    *args* is mutated in place (the same object is returned). The prompt gets a
    fixed quality suffix appended before it is stored.
    """
    for field, value in (('height', height), ('width', width), ('seed', seed)):
        setattr(args, field, value)
    args.prompt = prompt + ' rich detail, 4k, high quality'
    args.stage_a_tiled = stage_a_tiled
    return args
|
|
|
def is_korean(text):
    """Return True if *text* contains at least one precomposed Hangul syllable.

    Only the Hangul Syllables block (U+AC00..U+D7A3) is checked; standalone
    jamo such as "ㄱ" do not count.
    """
    for ch in text:
        if '\uac00' <= ch <= '\ud7a3':
            return True
    return False
|
|
|
def translate_if_korean(text):
    """Translate *text* to English when it contains Hangul; otherwise return it as-is.

    Translation uses the module-level ``translator`` pipeline (capped at
    ``max_length=512``), and each translation is printed for logging.
    """
    if not is_korean(text):
        return text
    outputs = translator(text, max_length=512)
    english = outputs[0]['translation_text']
    print(f"Translated from Korean: {text} -> {english}")
    return english
|
|
|
def _seed_everything(seed):
    """Seed PyTorch, Python's ``random`` module and NumPy's global RNG with *seed*."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)


def _configure_sampling(cfg, timesteps):
    """Write one request's sampling settings into the module-level extras.

    Stage C (``extras``) uses the caller-supplied CFG scale and step count;
    stage B (``extras_b``) keeps its fixed decoding settings.
    """
    stage_c = extras.sampling_configs
    stage_c['cfg'] = cfg
    stage_c['shift'] = 1
    stage_c['timesteps'] = timesteps
    stage_c['t_start'] = 1.0
    stage_c['sampler'] = DDPMSampler(extras.gdf)

    stage_b = extras_b.sampling_configs
    stage_b['cfg'] = 1.1
    stage_b['shift'] = 1
    stage_b['timesteps'] = 10
    stage_b['t_start'] = 1.0


@spaces.GPU(duration=120)
def get_image(height, width, seed, prompt, cfg, timesteps, stage_a_tiled):
    """Generate an image for the Gradio UI and return it.

    Korean prompts are translated to English first. The request settings are
    written onto the module-level ``args`` via ``load_message``, and all RNGs
    are seeded. Stage C then samples a latent using the chosen ``cfg`` and
    ``timesteps``, and stage B (+A) decodes it.

    Args:
        height: Output height in pixels.
        width: Output width in pixels.
        seed: Seed shared by torch / random / NumPy.
        prompt: Text prompt, Korean or English.
        cfg: CFG scale used for stage C sampling.
        timesteps: Number of stage C sampling steps.
        stage_a_tiled: Forwarded to ``decode_b`` as ``stage_a_tiled``.

    Returns:
        ``show_images(...)[0]`` for the last caption generated.

    Raises:
        ValueError: if ``args.num_image`` is smaller than 1 (nothing to generate).
    """
    global args

    prompt = translate_if_korean(prompt)
    # Gradio number/slider values may arrive as floats. np.random.seed rejects
    # floats such as 123.0, and sizes/step counts must be integers.
    height, width, seed = int(height), int(width), int(seed)
    args = load_message(height, width, seed, prompt, args, stage_a_tiled)
    if args.num_image < 1:
        raise ValueError(f"num_image must be >= 1, got {args.num_image}")

    _seed_everything(args.seed)
    dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float

    captions = [args.prompt] * args.num_image
    height, width = args.height, args.width
    batch_size = 1
    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
    stage_c_latent_shape_lr, _ = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)

    # Honour the UI's CFG / Timesteps sliders for stage C.
    _configure_sampling(float(cfg), int(timesteps))

    imgs = None
    for caption in captions:
        batch = {'captions': [caption] * batch_size}

        with torch.no_grad():
            models.generator.cuda()
            try:
                print('STAGE C GENERATION***************************')
                with torch.autocast(device_type='cuda', dtype=dtype):
                    sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)
            finally:
                # Move stage C back to the CPU even if sampling fails, so its
                # GPU memory is released for stage B / later requests.
                models.generator.cpu()
                torch.cuda.empty_cache()

            # Stage B conditions are built once per caption; the stage C
            # latent is injected as the 'effnet' condition (zeros when unconditional).
            conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
            unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
            conditions_b['effnet'] = sampled_c
            unconditions_b['effnet'] = torch.zeros_like(sampled_c)

            print('STAGE B + A DECODING***************************')
            with torch.autocast(device_type='cuda', dtype=dtype):
                sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled)

            torch.cuda.empty_cache()
            imgs = show_images(sampled)

    return imgs[0]
|
|
|
# Custom CSS passed to gr.Blocks: hides the page footer.
css = """
footer {
    visibility: hidden;
}
"""
|
|
|
# Gradio UI: prompt box + Submit, output image, advanced settings (seed, size,
# CFG, timesteps, stage A tiling), a Clear button and example prompts.
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("<h1><center>UHD(MAX 5120 X 4096 Pixel) Image Gen</center></h1>")

        with gr.Row():
            prompt = gr.Textbox(
                label="Text Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt in Korean or English",
                container=False,
            )
            polish_button = gr.Button("Submit", scale=0)

        output_img = gr.Image(label="Output Image", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            # precision=0 makes Gradio deliver the seed as an int
            # (np.random.seed rejects floats such as 123.0).
            seed = gr.Number(
                label="Random Seed",
                value=123,
                step=1,
                minimum=0,
                precision=0,
            )

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=1536,
                    maximum=5120,
                    step=32,
                    value=4096,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=1536,
                    maximum=4096,
                    step=32,
                    value=2304,
                )

            with gr.Row():
                cfg = gr.Slider(
                    label="CFG",
                    minimum=3,
                    maximum=10,
                    step=0.1,
                    value=4,
                )
                timesteps = gr.Slider(
                    label="Timesteps",
                    minimum=10,
                    maximum=50,
                    step=1,
                    value=20,
                )

            stage_a_tiled = gr.Checkbox(label="Stage_a_tiled", value=False)

        clear_button = gr.Button("Clear!")

        # Korean examples are translated to English by get_image before generation.
        gr.Examples(
            examples=[
                "A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a clear blue sky.",
                "눈 덮인 산맥의 장엄한 전경, 푸른 하늘을 배경으로 한 고요한 호수가 있는 모습",
                "The image features a snow-covered mountain range with a large, snow-covered mountain in the background. The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the winter season, with snow covering the ground and the trees.",
                "스웨터를 입은 악어",
                "A vibrant anime scene of a young girl with long, flowing pink hair, big sparkling blue eyes, and a school uniform, standing under a cherry blossom tree with petals falling around her. The background shows a traditional Japanese school with cherry blossoms in full bloom.",
                "골든리트리버 강아지가 푸른 잔디밭에서 빨간 공을 쫓는 귀여운 모습",
                "A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm lights glowing from the windows, and a path of footprints leading to the front door.",
            ],
            inputs=[prompt],
            outputs=[output_img],
            examples_per_page=5,
        )

    polish_button.click(
        get_image,
        inputs=[height, width, seed, prompt, cfg, timesteps, stage_a_tiled],
        outputs=output_img,
    )
    # Clearing is bound to the "Clear!" button. Binding it to Submit would race
    # with get_image and leave "Clear!" inert.
    clear_button.click(clear_image, inputs=[], outputs=output_img)
|
|
|
def download_with_wget(url, save_path):
    """Best-effort download of *url* to *save_path* using the external ``wget`` binary.

    Failures are reported on stdout instead of raised: a non-zero wget exit,
    or wget not being runnable at all (e.g. not installed).

    Returns:
        True if wget exited successfully, False otherwise.
    """
    import contextlib

    try:
        subprocess.run(['wget', url, '-O', save_path], check=True)
    except OSError as e:
        # wget could not be launched (missing from PATH, not executable, ...).
        print(f"Error downloading file: {e}")
        return False
    except subprocess.CalledProcessError as e:
        print(f"Error downloading file: {e}")
        # `wget -O` opens/truncates the target up front, so a failed run can
        # leave an empty or partial file that would later look like a model.
        with contextlib.suppress(OSError):
            os.remove(save_path)
        return False
    print(f"Downloaded to {save_path}")
    return True
|
|
|
def download_model(local_dir='models'):
    """Download the Stable Cascade stage weights and the UltraPixel checkpoint.

    Every file is fetched with ``hf_hub_download`` into *local_dir*. The
    default ``'models'`` matches the default ``--pretrained_path``
    (``models/ultrapixel_t2i.safetensors``); the stage weight paths inside the
    YAML configs presumably point there too (confirm against the configs).

    Args:
        local_dir: Target directory for all downloaded files.
    """
    # Only the filenames matter: hf_hub_download resolves them inside repo_id.
    # They used to be derived from full URLs into a different repo
    # (stabilityai/StableWurst), which was misleading.
    cascade_files = (
        'stage_a.safetensors',
        'previewer.safetensors',
        'effnet_encoder.safetensors',
        'stage_b_lite_bf16.safetensors',
        'stage_c_bf16.safetensors',
    )
    for filename in cascade_files:
        hf_hub_download(repo_id="stabilityai/stable-cascade", filename=filename, local_dir=local_dir)
    hf_hub_download(repo_id="roubaofeipi/UltraPixel", filename='ultrapixel_t2i.safetensors', local_dir=local_dir)
|
|
|
if __name__ == "__main__":
    # Names bound here (args, device, core, core_b, extras, extras_b, models,
    # models_b) are module globals read by get_image.
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    download_model()

    # Stage C (latent generation) core, built from its YAML config.
    with open(args.config_c, "r", encoding="utf-8") as cfg_fh:
        stage_c_config = yaml.safe_load(cfg_fh)
    core = WurstCoreC(config_dict=stage_c_config, device=device, training=False)

    # Stage B (latent decoding) core, built from its YAML config.
    with open(args.config_b, "r", encoding="utf-8") as cfg_fh:
        stage_b_config = yaml.safe_load(cfg_fh)
    core_b = WurstCoreB(config_dict=stage_b_config, device=device, training=False)

    extras = core.setup_extras_pre()
    models = core.setup_models(extras)
    models.generator.eval().requires_grad_(False)
    print("STAGE C READY")

    extras_b = core_b.setup_extras_pre()
    models_b = core_b.setup_models(extras_b, skip_clip=True)
    # Stage B takes its tokenizer and text model from stage C
    # (skip_clip=True presumably skips loading its own — confirm in WurstCoreB).
    models_b = WurstCoreB.Models(
        **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
    )
    models_b.generator.bfloat16().eval().requires_grad_(False)
    print("STAGE B READY")

    # NOTE(review): torch.load on a *.safetensors path implies the checkpoint is
    # actually a pickled state dict; confirm before switching loaders.
    raw_state = torch.load(args.pretrained_path, map_location='cpu')
    # Drop the first 7 characters of every key — presumably a "module." prefix
    # from a DataParallel/DDP save; TODO confirm.
    train_norm_state = {key[7:]: tensor for key, tensor in raw_state.items()}

    models.train_norm.load_state_dict(train_norm_state)
    models.generator.eval()
    models.train_norm.eval()

    demo.launch(debug=True, share=True)