visconet / app.py
soonyau's picture
Revert step number and guidance scale
48c9f06 verified
from share import *
import config
import os
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import re
from datetime import datetime
from glob import glob
import argparse
from pytorch_lightning import seed_everything
from torchvision.transforms import ToPILImage
from annotator.util import pad_image, resize_image, HWC3
from annotator.openpose import OpenposeDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from pathlib import Path
from PIL import Image
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config, log_txt_as_img
from visconet.segm import ATRSegmentCropper as SegmentCropper
from huggingface_hub import snapshot_download
# supply directory of visual prompt images
HF_REPO = 'soonyau/visconet'
GALLERY_PATH = Path('./fashion/')
WOMEN_GALLERY_PATH = GALLERY_PATH/'WOMEN'
MEN_GALLERY_PATH = GALLERY_PATH/'MEN'
DEMO = True
LOG_SAMPLES = False
APP_FILES_PATH = Path('./app_files')
VISCON_IMAGE_PATH = APP_FILES_PATH/'default_images'
LOG_PATH = APP_FILES_PATH/'logs'
SAMPLE_IMAGE_PATH = APP_FILES_PATH/'samples'
DEFAULT_CONTROL_SCALE = 1.0
SCALE_CONFIG = {
'Default': [DEFAULT_CONTROL_SCALE]*13,
'DeepFakes':[1.0, 1.0, 1.0,
1.0, 1.0, 1.0,
0.5, 0.5, 0.5,
0.0, 0.0, 0.0, 0.0,],
'Faithful':[1,1,1,
1,1,1,
1,1,0.5,
0.5,0.5,0,0],
'Painting':[0.0,0.0,0.0,
0.5,0.5,0.5,
0.5,0.5,0.5,
0.5,0,0,0],
'Pose': [0.0,0.0,0.0,
0.0,0.0,0.0,
0.0,0.0,0.5,
0.0,0.0,0,0],
'Texture Transfer': [1.0,1.0,1.0,
1.0,1.0,1.0,
0.5,0.0,0.5,
0.0,0.0,0,0]
}
DEFAULT_SCALE_CONFIG = 'Default'
ignore_style_list = ['headwear', 'accesories', 'shoes']
global device
global segmentor
global apply_openpose
global style_encoder
global model
global ddim_sampler
def convert_fname(long_name):
gender = 'MEN' if long_name[7:10] == 'MEN' else 'WOMEN'
input_list = long_name.replace('fashion','').split('___')
# Define a regular expression pattern to match the relevant parts of each input string
if gender == 'MEN':
pattern = r'MEN(\w+)id(\d+)_(\d)(\w+)'
else:
pattern = r'WOMEN(\w+)id(\d+)_(\d)(\w+)'
# Use a list comprehension to extract the matching substrings from each input string, and format them into the desired output format
output_list = [f'{gender}/{category}/id_{id_num[:8]}/{id_num[8:]}_{view_num}_{view_desc}' for (category, id_num, view_num, view_desc) in re.findall(pattern, ' '.join(input_list))]
# Print the resulting list of formatted strings
return [f +'.jpg' for f in output_list]
def fetch_deepfashion(deepfashion_names):
src_name, dst_name = convert_fname(deepfashion_names)
input_image = np.array(Image.open(image_root/src_name))
pose_image = np.array(Image.open(str(pose_root/dst_name)))
mask_image = Image.open(str(mask_root/dst_name).replace('.jpg','_mask.png'))
temp = src_name.replace('.jpg','').split('/')
lastfolder = temp.pop(-1).replace('_','/', 1)
style_folder = style_root/('/'.join(temp+[lastfolder]))
viscon_images = []
for style_name in style_names:
f_path = style_folder/f'{style_name}.jpg'
if os.path.exists(str(f_path)):
viscon_images.append(np.array(Image.open(f_path)))
else:
viscon_images.append(None)
return [input_image, pose_image, mask_image, *viscon_images]
def select_gallery_image(evt: gr.SelectData):
return evt.target.value[evt.index]['name']
def select_default_strength(strength_config):
return SCALE_CONFIG[strength_config]
def change_all_scales(scale):
return [float(scale)]*13
def encode_style_images(style_images):
style_embeddings = []
for style_name, style_image in zip(style_names, style_images):
if style_image == None:
style_image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
#style_image = style_image.resize((224,224))
style_image = style_encoder.preprocess(style_image).to(device)
style_emb = style_encoder.postprocess(style_encoder(style_image)[0])
style_embeddings.append(style_emb)
styles = torch.tensor(np.array(style_embeddings)).squeeze(-2).unsqueeze(0).float().to(device)
return styles
def save_viscon_images(*viscon_images):
ret_images = []
for image, name in zip(viscon_images, style_names):
fname = str(VISCON_IMAGE_PATH/name)+'.jpg'
if image:
image = image.resize((224,224))
if os.path.exists(fname):
os.remove(fname)
image.save(fname)
ret_images.append(image)
return ret_images
def extract_pose_mask(input_image, detect_resolution,
ignore_head=True, ignore_hair=False):
# skeleton
input_image = pad_image(input_image, min_aspect_ratio=0.625)
detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution), hand=True)
detected_map = HWC3(detected_map)
# human mask
cropped = segmentor(input_image, ignore_head=ignore_head, ignore_hair=ignore_hair)
mask = cropped['human_mask']
mask = Image.fromarray(np.array(mask*255, dtype=np.uint8), mode='L')
return [detected_map, mask]
def extract_fashion(input_image):
# style images
cropped = segmentor(input_image)
cropped_images = []
for style_name in style_names:
if style_name in cropped and style_name not in ignore_style_list:
cropped_images.append(cropped[style_name])
else:
cropped_images.append(None)
return [*cropped_images]
def get_image_files(image_path, ret_image=True, exts=['.jpg','.jpeg','.png']):
images = []
for ext in exts:
images += [x for x in glob(str(Path(image_path)/f'*{ext}'))]
if ret_image:
images = [Image.open(x) for x in images]
return images
def log_sample(seed, results, prompt, skeleton_image, mask_image, control_scales, *viscon_images):
time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
log_dir = LOG_PATH/time_str
os.makedirs(str(log_dir), exist_ok=True)
# save result
concat = np.hstack((skeleton_image, *results))
Image.fromarray(skeleton_image).save(str(log_dir/'skeleton.jpg'))
Image.fromarray(mask_image).save(str(log_dir/'mask.png'))
for i, result in enumerate(results):
Image.fromarray(result).save(str(log_dir/f'result_{i}.jpg'))
# save text
with open(str(log_dir/'info.txt'),'w') as f:
f.write(f'prompt: {prompt} \n')
f.write(f'seed: {seed}\n')
control_str = [str(x) for x in control_scales]
f.write(','.join(control_str) + '\n')
# save vison images
for style_name, style_image in zip(style_names, viscon_images):
if style_image is not None:
style_image.save(str(log_dir/f'{style_name}.jpg'))
def process(prompt, a_prompt, n_prompt, num_samples,
ddim_steps, scale, seed, eta, mask_image, pose_image,
c12, c11, c10, c9, c8, c7, c6, c5, c4, c3, c2, c1, c0,
*viscon_images):
with torch.no_grad():
control_scales = [c12, c11, c10, c9, c8, c7, c6, c5, c4, c3, c2, c1, c0]
mask = torch.tensor(mask_image.mean(-1)/255.,dtype=torch.float) #(512,512), [0,1]
mask = mask.unsqueeze(0).to(device) # (1, 512, 512)
style_emb = encode_style_images(viscon_images)
# fix me
detected_map = HWC3(pose_image)
#detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
H, W, C = detected_map.shape
control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0
control = torch.stack([control for _ in range(num_samples)], dim=0)
control = einops.rearrange(control, 'b h w c -> b c h w').clone()
if seed == -1:
seed = random.randint(0, 65535)
seed_everything(seed)
if config.save_memory:
model.low_vram_shift(is_diffusing=False)
new_style_shape = [num_samples] + [1] * (len(style_emb.shape)-1)
cond = {"c_concat": [control],
"c_crossattn": [style_emb.repeat(new_style_shape)],
"c_text": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
'c_concat_mask': [mask.repeat(num_samples, 1, 1, 1)]}
un_cond = {"c_concat": [control],
"c_crossattn": [torch.zeros_like(style_emb).repeat(new_style_shape)],
"c_text":[model.get_learned_conditioning([n_prompt] * num_samples)],
'c_concat_mask': [torch.zeros_like(mask).repeat(num_samples, 1, 1, 1)]}
shape = (4, H // 8, W // 8)
if config.save_memory:
model.low_vram_shift(is_diffusing=True)
model.control_scales = control_scales
samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
shape, cond, verbose=False, eta=eta,
unconditional_guidance_scale=scale,
unconditional_conditioning=un_cond)
if config.save_memory:
model.low_vram_shift(is_diffusing=False)
x_samples = model.decode_first_stage(samples)
x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
results = [x_samples[i] for i in range(num_samples)]
if LOG_SAMPLES:
log_sample(seed, results, prompt, detected_map, mask_image, control_scales, *viscon_images)
return results
def get_image(name, file_ext='.jpg'):
fname = str(VISCON_IMAGE_PATH/name)+file_ext
if not os.path.exists(fname):
return None
return Image.open(fname)
def get_image_numpy(name, file_ext='.png'):
fname = str(VISCON_IMAGE_PATH/name)+file_ext
if not os.path.exists(fname):
return None
return np.array(Image.open(fname))
def create_app():
block = gr.Blocks().queue()
with block:
with gr.Row():
gr.Markdown("## ViscoNet: Visual ControlNet with Human Pose and Fashion <br> [Video tutorial](https://youtu.be/85NyIuLeV00)")
with gr.Row():
with gr.Column():
with gr.Accordion("Get pose and mask", open=True):
with gr.Row():
input_image = gr.Image(source='upload', type="numpy", label='input image', value=np.array(get_image_numpy('ref')))
pose_image = gr.Image(source='upload', type="numpy", label='pose', value=np.array(get_image_numpy('pose')))
mask_image = gr.Image(source='upload', type="numpy", label='mask', value=np.array(get_image_numpy('mask')))
with gr.Accordion("Human Pose Samples", open=False):
with gr.Tab('Female'):
samples = get_image_files(str(SAMPLE_IMAGE_PATH/'pose/WOMEN/'))
female_pose_gallery = gr.Gallery(label='pose', show_label=False, value=samples).style(grid=3, height='auto')
with gr.Tab('Male'):
samples = get_image_files(str(SAMPLE_IMAGE_PATH/'pose/MEN/'))
male_pose_gallery = gr.Gallery(label='pose', show_label=False, value=samples).style(grid=3, height='auto')
with gr.Row():
#pad_checkbox = gr.Checkbox(label='Pad pose to square', value=True)
ignorehead_checkbox = gr.Checkbox(label='Ignore face in masking (for faceswap with text)', value=False)
ignorehair_checkbox = gr.Checkbox(label='Ignore hair in masking', value=False, visible=True)
with gr.Row():
#ignore_head_checkbox = gr.Checkbox(label='Ignore head', value=False)
get_pose_button = gr.Button(label="Get pose", value='Get pose')
get_fashion_button = gr.Button(label="Get visual", value='Get visual prompt')
with gr.Accordion("Visual Conditions", open=True):
gr.Markdown('Drag-and-drop, or click from samples below.')
with gr.Column():
viscon_images = []
viscon_images_names2index = {}
viscon_len = len(style_names)
v_idx = 0
with gr.Row():
for _ in range(8):
viscon_name = style_names[v_idx]
vis = False if viscon_name in ignore_style_list else True
viscon_images.append(gr.Image(source='upload', type="pil", min_height=112, min_width=112, label=viscon_name, value=get_image(viscon_name), visible=vis))
viscon_images_names2index[viscon_name] = v_idx
v_idx += 1
viscon_button = gr.Button(value='Save as Default',visible=False if DEMO else True)
viscon_galleries = []
with gr.Accordion("Virtual Try-on", open=False):
with gr.Column():
#with gr.Accordion("Female", open=False):
with gr.Tab('Female'):
for garment, number in zip(['face', 'hair', 'top', 'bottom', 'outer'], [50, 150, 500, 500, 250]):
with gr.Tab(garment):
samples = []
if WOMEN_GALLERY_PATH and os.path.exists(WOMEN_GALLERY_PATH):
samples = glob(os.path.join(WOMEN_GALLERY_PATH, f'**/{garment}.jpg'), recursive=True)
samples = random.choices(samples, k=number)
viscon_gallery = gr.Gallery(label='hair', allow_preview=False, show_label=False, value=samples).style(grid=4, height='auto')
viscon_galleries.append({'component':viscon_gallery, 'inputs':[garment]})
#with gr.Accordion("Male", open=False):
with gr.Tab('Male'):
for garment, number in zip(['face','hair', 'top', 'bottom', 'outer'], [50, 150, 500, 500, 250]):
with gr.Tab(garment):
samples = []
if MEN_GALLERY_PATH and os.path.exists(MEN_GALLERY_PATH):
samples = glob(os.path.join(MEN_GALLERY_PATH, f'**/{garment}.jpg'), recursive=True)
samples = random.choices(samples, k=number)
viscon_gallery = gr.Gallery(label='hair', allow_preview=False, show_label=False, value=samples).style(grid=4, height='auto')
viscon_galleries.append({'component':viscon_gallery, 'inputs':[garment]})
with gr.Column():
result_gallery = gr.Gallery(label='Output', show_label=False, show_download_button=True, elem_id="gallery").style(grid=1, height='auto')
with gr.Row():
max_samples = 8 if not DEMO else 4
num_samples = gr.Slider(label="Images", minimum=1, maximum=max_samples, value=1, step=1)
scale_all = gr.Slider(label=f'Control Strength', minimum=0, maximum=1, value=DEFAULT_CONTROL_SCALE, step=0.05)
seed = gr.Slider(label="Seed (-1 for random)", minimum=-1, maximum=2147483647, step=1, value=1561194236)#randomize=True) #value=1561194234)
if not DEMO:
DF_DEMO = 'fashionWOMENTees_Tanksid0000762403_1front___fashionWOMENTees_Tanksid0000762403_1front'
DF_EVAL = 'fashionWOMENBlouses_Shirtsid0000035501_1front___fashionWOMENBlouses_Shirtsid0000035501_1front'
DF_RESULT ="fashionWOMENTees_Tanksid0000796209_1front___fashionWOMENTees_Tanksid0000796209_2side"
deepfashion_names = gr.Textbox(label='Deepfashion name', value=DF_EVAL)
gr.Markdown("Default config reconstruct image faithful to pose, mask and visual condition. Reduce control strength to tip balance towards text prompt for more creativity.")
prompt = gr.Textbox(label="Text Prompt", value="")
run_button = gr.Button(label="Run")
with gr.Accordion("Advanced options", open=False):
with gr.Accordion("Control Strength Scaling", open=False):
gr.Markdown("smaller value for stronger textual influence. c12 is highest spatial resolution controlling textures")
strength_select = gr.Dropdown(list(SCALE_CONFIG.keys()), label='strength settings', value=DEFAULT_SCALE_CONFIG)
scale_values = SCALE_CONFIG[DEFAULT_SCALE_CONFIG]
control_scales = []
c_idx = 12
with gr.Accordion("Advanced settings", open=False):
with gr.Row():
for _ in range(3):
control_scales.append(gr.Slider(label=f'c{c_idx}', minimum=0, maximum=1, value=scale_values[12-c_idx], step=0.05))
c_idx -= 1
with gr.Row():
for _ in range(3):
control_scales.append(gr.Slider(label=f'c{c_idx}', minimum=0, maximum=1, value=scale_values[12-c_idx], step=0.05))
c_idx -= 1
with gr.Row():
for _ in range(3):
control_scales.append(gr.Slider(label=f'c{c_idx}', minimum=0, maximum=1, value=scale_values[12-c_idx], step=0.05))
c_idx -= 1
with gr.Row():
for _ in range(4):
control_scales.append(gr.Slider(label=f'c{c_idx}', minimum=0, maximum=1, value=scale_values[12-c_idx], step=0.05))
c_idx -= 1
with gr.Row():
detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=768, value=512, step=1)
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1)
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=12.0, step=0.1)
eta = gr.Number(label="eta (DDIM)", value=0.0, visible=False)
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
n_prompt = gr.Textbox(label="Negative Prompt",
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, sunglasses, hat')
female_pose_gallery.select(fn=select_gallery_image, inputs=None, outputs=input_image)
male_pose_gallery.select(fn=select_gallery_image, inputs=None, outputs=input_image)
for vision_gallery in viscon_galleries:
viscon_idx = viscon_images_names2index[vision_gallery['inputs'][0]]
vision_gallery['component'].select(fn=select_gallery_image, inputs=None,
outputs=viscon_images[viscon_idx])
ips = [prompt, a_prompt, n_prompt, num_samples, ddim_steps, scale, seed, eta, mask_image, pose_image,
*control_scales, *viscon_images]
run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
prompt.submit(fn=process, inputs=ips, outputs=[result_gallery])
get_pose_button.click(fn=extract_pose_mask, inputs=[input_image, detect_resolution,
ignorehead_checkbox, ignorehair_checkbox],
outputs=[pose_image, mask_image])
get_fashion_button.click(fn=extract_fashion, inputs=input_image, outputs=[*viscon_images])
viscon_button.click(fn=save_viscon_images, inputs=[*viscon_images], outputs=[*viscon_images])
strength_select.select(fn=select_default_strength, inputs=[strength_select], outputs=[*control_scales])
scale_all.release(fn=change_all_scales, inputs=[scale_all], outputs=[*control_scales])
if not DEMO:
deepfashion_names.submit(fn=fetch_deepfashion, inputs=[deepfashion_names], outputs=[input_image, pose_image, mask_image, *viscon_images])
return block
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('--gpu', type=int, default=0, help='GPU id')
parser.add_argument('--config', type=str, default='./configs/visconet_v1.yaml')
parser.add_argument('--ckpt', type=str, default='./models/visconet_v1.pth')
parser.add_argument('--public_link', action='store_true', default='', help='Create public link')
args = parser.parse_args()
global device
global segmentor
global apply_openpose
global style_encoder
global model
global ddim_sampler
device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'
config_file = args.config
model_ckpt = args.ckpt
proj_config = OmegaConf.load(config_file)
style_names = proj_config.dataset.train.params.style_names
data_root = Path(proj_config.dataset.train.params.image_root)
image_root = data_root/proj_config.dataset.train.params.image_dir
style_root = data_root/proj_config.dataset.train.params.style_dir
pose_root = data_root/proj_config.dataset.train.params.pose_dir
mask_root = data_root/proj_config.dataset.train.params.mask_dir
segmentor = SegmentCropper()
apply_openpose = OpenposeDetector()
if not os.path.exists(model_ckpt):
snapshot_download(repo_id=HF_REPO, local_dir='./models',
allow_patterns=os.path.basename(model_ckpt))
style_encoder = instantiate_from_config(proj_config.model.style_embedding_config).to(device)
model = create_model(config_file).cpu()
model.load_state_dict(load_state_dict(model_ckpt, location=device))
model = model.to(device)
model.cond_stage_model.device = device
ddim_sampler = DDIMSampler(model)
if not GALLERY_PATH.exists():
zip_name = 'fashion.zip'
snapshot_download(repo_id=HF_REPO, allow_patterns=zip_name, local_dir='.')
from zipfile import ZipFile
with ZipFile(zip_name, 'r') as zip_ref:
zip_ref.extractall('.')
os.remove(zip_name)
# Calling the main function with parsed arguments
block = create_app()
block.launch(share=args.public_link)