import os
import yaml
import torch
import sys

# Make the repo-local packages (inference, train, gdf) importable.
sys.path.append(os.path.abspath('./'))
from inference.utils import *
from train import WurstCoreB
from gdf import DDPMSampler
from train import WurstCore_t2i as WurstCoreC
import numpy as np
import random
import argparse
import gradio as gr
import spaces
from huggingface_hub import hf_hub_url
import subprocess
from huggingface_hub import hf_hub_download
from transformers import pipeline

# Korean -> English translation pipeline, built once at import time and
# reused by every request.
translator = pipeline("translation", model="Helsinki-NLP/opus-mt-ko-en")


def parse_args():
    """Parse the command-line options of the UltraPixel demo.

    Returns:
        argparse.Namespace holding the generation settings (resolution, seed,
        dtype, prompt, number of images, output dir, tiled stage A) and the
        stage C / stage B config paths plus the UltraPixel checkpoint path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--height', type=int, default=2560, help='image height')
    parser.add_argument('--width', type=int, default=5120, help='image width')
    parser.add_argument('--seed', type=int, default=123, help='random seed')
    parser.add_argument('--dtype', type=str, default='bf16', help='if bf16 does not work, change it to float32')
    parser.add_argument('--config_c', type=str, default='configs/training/t2i.yaml', help='config file for stage c, latent generation')
    parser.add_argument('--config_b', type=str, default='configs/inference/stage_b_1b.yaml', help='config file for stage b, latent decoding')
    parser.add_argument('--prompt', type=str, default='A photo-realistic image of a west highland white terrier in the garden, high quality, detail rich, 8K', help='text prompt')
    parser.add_argument('--num_image', type=int, default=1, help='how many images generated')
    parser.add_argument('--output_dir', type=str, default='figures/output_results/', help='output directory for generated image')
    parser.add_argument('--stage_a_tiled', action='store_true', help='whether or not to use tiled decoding for stage a to save memory')
    parser.add_argument('--pretrained_path', type=str, default='models/ultrapixel_t2i.safetensors', help='pretrained path of newly added parameter of UltraPixel')
    args = parser.parse_args()
    return args


def clear_image():
    """Return ``None``; used as a Gradio handler to empty the output image."""
    return None


def load_message(height, width, seed, prompt, args, stage_a_tiled):
    """Copy the UI inputs onto ``args`` (mutated in place) and return it.

    A fixed quality suffix is appended to the prompt.
    """
    args.height = height
    args.width = width
    args.seed = seed
    args.prompt = prompt + ' rich detail, 4k, high quality'
    args.stage_a_tiled = stage_a_tiled
    return args


def is_korean(text):
    """Return True if ``text`` contains a precomposed Hangul syllable (U+AC00..U+D7A3)."""
    return any('\uac00' <= char <= '\ud7a3' for char in text)


def translate_if_korean(text):
    """Translate ``text`` to English when it contains Hangul; otherwise return it unchanged."""
    if is_korean(text):
        translated = translator(text, max_length=512)[0]['translation_text']
        print(f"Translated from Korean: {text} -> {translated}")
        return translated
    return text


@spaces.GPU(duration=120)
def get_image(height, width, seed, prompt, cfg, timesteps, stage_a_tiled):
    """Generate one UHD image from the Gradio inputs.

    Args:
        height: target image height in pixels.
        width: target image width in pixels.
        seed: random seed (Gradio ``Number`` may deliver it as a float).
        prompt: text prompt in Korean or English; Korean is translated first.
        cfg: classifier-free guidance scale used for stage C sampling.
        timesteps: number of stage C sampling steps.
        stage_a_tiled: whether stage A decoding runs tiled.

    Returns:
        The first image produced by ``show_images``.

    Reads the module-level globals ``args``, ``device``, ``core``, ``core_b``,
    ``models``, ``models_b``, ``extras`` and ``extras_b`` created in the
    ``__main__`` block, and rebinds ``args``.
    """
    global args

    prompt = translate_if_korean(prompt)
    # Gradio Number/Slider values can arrive as floats. numpy's seeding rejects
    # floats, and the latent-size helpers presumably expect integer pixel sizes.
    args = load_message(int(height), int(width), int(seed), prompt, args, stage_a_tiled)

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    np.random.seed(args.seed)
    dtype = torch.bfloat16 if args.dtype == 'bf16' else torch.float

    captions = [args.prompt] * args.num_image
    height, width = args.height, args.width
    batch_size = 1
    height_lr, width_lr = get_target_lr_size(height / width, std_size=32)
    stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)
    stage_c_latent_shape_lr, stage_b_latent_shape_lr = calculate_latent_sizes(height_lr, width_lr, batch_size=batch_size)

    # Stage C Parameters -- honour the UI sliders instead of fixed 4 / 20
    # (those were the slider defaults, so default behaviour is unchanged).
    extras.sampling_configs['cfg'] = cfg
    extras.sampling_configs['shift'] = 1
    extras.sampling_configs['timesteps'] = int(timesteps)
    extras.sampling_configs['t_start'] = 1.0
    extras.sampling_configs['sampler'] = DDPMSampler(extras.gdf)

    # Stage B Parameters
    extras_b.sampling_configs['cfg'] = 1.1
    extras_b.sampling_configs['shift'] = 1
    extras_b.sampling_configs['timesteps'] = 10
    extras_b.sampling_configs['t_start'] = 1.0

    imgs = None
    for caption in captions:
        batch = {'captions': [caption] * batch_size}
        with torch.no_grad():
            models.generator.cuda()
            print('STAGE C GENERATION***************************')
            with torch.autocast(device_type='cuda', dtype=dtype):
                sampled_c = generation_c(batch, models, extras, core, stage_c_latent_shape, stage_c_latent_shape_lr, device)

            # Move stage C back to the CPU before stage B runs, freeing its GPU memory.
            models.generator.cpu()
            torch.cuda.empty_cache()

            # Stage B is conditioned on the stage C latent; the unconditional
            # branch gets an all-zero latent of the same shape.
            conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
            unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)
            conditions_b['effnet'] = sampled_c
            unconditions_b['effnet'] = torch.zeros_like(sampled_c)

            print('STAGE B + A DECODING***************************')
            with torch.autocast(device_type='cuda', dtype=dtype):
                sampled = decode_b(conditions_b, unconditions_b, models_b, stage_b_latent_shape, extras_b, device, stage_a_tiled=args.stage_a_tiled)

            torch.cuda.empty_cache()
            imgs = show_images(sampled)
    return imgs[0]


css = """
footer {
    visibility: hidden;
}
"""

with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("\n\n초고해상도 UHD(최대 5120 X 4096 픽셀) 이미지 생성\n\n")

        with gr.Row():
            prompt = gr.Textbox(
                label="Text Prompt (한글 또는 영어로 입력하세요)",
                show_label=False,
                max_lines=1,
                placeholder="프롬프트를 입력하세요 (Enter your prompt in Korean or English)",
                container=False,
            )
            polish_button = gr.Button("제출! (Submit!)", scale=0)

        output_img = gr.Image(label="Output Image", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            seed = gr.Number(
                label="Random Seed",
                value=123,
                step=1,
                minimum=0,
            )

            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=1536,
                    maximum=5120,
                    step=32,
                    value=4096,
                )
                height = gr.Slider(
                    label="Height",
                    minimum=1536,
                    maximum=4096,
                    step=32,
                    value=2304,
                )

            with gr.Row():
                cfg = gr.Slider(
                    label="CFG",
                    minimum=3,
                    maximum=10,
                    step=0.1,
                    value=4,
                )
                timesteps = gr.Slider(
                    label="Timesteps",
                    minimum=10,
                    maximum=50,
                    step=1,
                    value=20,
                )

            stage_a_tiled = gr.Checkbox(label="Stage_a_tiled", value=False)

        clear_button = gr.Button("Clear!")

        gr.Examples(
            examples=[
                "A detailed view of a blooming magnolia tree, with large, white flowers and dark green leaves, set against a clear blue sky.",
                "눈 덮인 산맥의 장엄한 전경, 푸른 하늘을 배경으로 한 고요한 호수가 있는 모습",
                "The image features a snow-covered mountain range with a large, snow-covered mountain in the background. The mountain is surrounded by a forest of trees, and the sky is filled with clouds. The scene is set during the winter season, with snow covering the ground and the trees.",
                "스웨터를 입은 악어",
                "A vibrant anime scene of a young girl with long, flowing pink hair, big sparkling blue eyes, and a school uniform, standing under a cherry blossom tree with petals falling around her. The background shows a traditional Japanese school with cherry blossoms in full bloom.",
                "골든 리트리버 강아지가 푸른 잔디밭에서 빨간 공을 쫓는 귀여운 모습",
                "A cozy, rustic log cabin nestled in a snow-covered forest, with smoke rising from the stone chimney, warm lights glowing from the windows, and a path of footprints leading to the front door.",
                "캐나다 밴프 국립공원의 아름다운 풍경, 청록색 호수와 눈 덮인 산들, 울창한 소나무 숲이 어우러진 모습",
                "귀여운 시츄가 욕조에서 목욕하는 모습, 거품에 둘러싸인 채 살짝 젖은 모습으로 카메라를 바라보고 있음",
            ],
            inputs=[prompt],
            outputs=[output_img],
            examples_per_page=5,
        )

    # Submit: clear the previous result, then generate. Chaining with .then()
    # fixes the order instead of racing two independent listeners.
    polish_button.click(clear_image, inputs=[], outputs=output_img).then(
        get_image,
        inputs=[height, width, seed, prompt, cfg, timesteps, stage_a_tiled],
        outputs=output_img,
    )
    # The "Clear!" button previously had no handler at all.
    clear_button.click(clear_image, inputs=[], outputs=output_img)


def download_with_wget(url, save_path):
    """Best-effort download of ``url`` to ``save_path`` via the ``wget`` CLI.

    Failures (non-zero exit or ``wget`` not installed) are printed, not raised.
    """
    try:
        subprocess.run(['wget', url, '-O', save_path], check=True)
        print(f"Downloaded to {save_path}")
    except subprocess.CalledProcessError as e:
        print(f"Error downloading file: {e}")
    except FileNotFoundError as e:
        # wget binary missing: keep this helper best-effort rather than crashing.
        print(f"Error downloading file: {e}")


def download_model():
    """Download the Stable Cascade and UltraPixel checkpoints into ``models/``."""
    stable_cascade_files = (
        'stage_a.safetensors',
        'previewer.safetensors',
        'effnet_encoder.safetensors',
        'stage_b_lite_bf16.safetensors',
        'stage_c_bf16.safetensors',
    )
    for filename in stable_cascade_files:
        hf_hub_download(repo_id="stabilityai/stable-cascade", filename=filename, local_dir='models')
    hf_hub_download(repo_id="roubaofeipi/UltraPixel", filename='ultrapixel_t2i.safetensors', local_dir='models')


if __name__ == "__main__":
    # Everything created here is read as a module-level global by get_image.
    args = parse_args()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    download_model()

    # SETUP STAGE C
    with open(args.config_c, "r", encoding="utf-8") as file:
        loaded_config = yaml.safe_load(file)
    core = WurstCoreC(config_dict=loaded_config, device=device, training=False)

    # SETUP STAGE B
    with open(args.config_b, "r", encoding="utf-8") as file:
        loaded_config_b = yaml.safe_load(file)
    core_b = WurstCoreB(config_dict=loaded_config_b, device=device, training=False)

    extras = core.setup_extras_pre()
    models = core.setup_models(extras)
    models.generator.eval().requires_grad_(False)
    print("STAGE C READY")

    extras_b = core_b.setup_extras_pre()
    models_b = core_b.setup_models(extras_b, skip_clip=True)
    # Stage B reuses stage C's tokenizer and text model.
    models_b = WurstCoreB.Models(
        **{**models_b.to_dict(), 'tokenizer': models.tokenizer, 'text_model': models.text_model}
    )
    # NOTE(review): stage B is cast to bf16 regardless of --dtype; confirm intended.
    models_b.generator.bfloat16().eval().requires_grad_(False)
    print("STAGE B READY")

    # NOTE(review): loaded with torch.load although the default path ends in
    # .safetensors -- presumably the file is actually a torch pickle; verify.
    sdd = torch.load(args.pretrained_path, map_location='cpu')
    # Drop the first 7 characters of every key -- presumably a 'module.'
    # prefix from a wrapped (e.g. DDP) model; TODO confirm.
    collect_sd = {k[7:]: v for k, v in sdd.items()}
    models.train_norm.load_state_dict(collect_sd)

    models.generator.eval()
    models.train_norm.eval()

    demo.launch(debug=True, share=True)