from share import *
import config
import os
import cv2
import einops
import gradio as gr
import numpy as np
import torch
import random
import re
from datetime import datetime
from glob import glob
import argparse
from pytorch_lightning import seed_everything
from torchvision.transforms import ToPILImage
from annotator.util import pad_image, resize_image, HWC3
from annotator.openpose import OpenposeDetector
from cldm.model import create_model, load_state_dict
from cldm.ddim_hacked import DDIMSampler
from pathlib import Path
from PIL import Image
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config, log_txt_as_img
from visconet.segm import ATRSegmentCropper as SegmentCropper
from huggingface_hub import snapshot_download
# supply directory of visual prompt images
HF_REPO = 'soonyau/visconet'
GALLERY_PATH = Path('./fashion/')
WOMEN_GALLERY_PATH = GALLERY_PATH/'WOMEN'
MEN_GALLERY_PATH = GALLERY_PATH/'MEN'

DEMO = True
LOG_SAMPLES = False

APP_FILES_PATH = Path('./app_files')
VISCON_IMAGE_PATH = APP_FILES_PATH/'default_images'
LOG_PATH = APP_FILES_PATH/'logs'
SAMPLE_IMAGE_PATH = APP_FILES_PATH/'samples'

DEFAULT_CONTROL_SCALE = 1.0
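# Each preset below holds 13 per-layer control strengths, ordered c12 (highest
# spatial resolution, controlling fine textures) down to c0 (coarsest); they are
# exposed in the "Control Strength Scaling" accordion and trade faithfulness to
# the pose/mask/visual conditions against freedom for the text prompt.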
SCALE_CONFIG = {
    'Default': [DEFAULT_CONTROL_SCALE]*13,
    'DeepFakes': [1.0, 1.0, 1.0,
                  1.0, 1.0, 1.0,
                  0.5, 0.5, 0.5,
                  0.0, 0.0, 0.0, 0.0],
    'Faithful': [1, 1, 1,
                 1, 1, 1,
                 1, 1, 0.5,
                 0.5, 0.5, 0, 0],
    'Painting': [0.0, 0.0, 0.0,
                 0.5, 0.5, 0.5,
                 0.5, 0.5, 0.5,
                 0.5, 0, 0, 0],
    'Pose': [0.0, 0.0, 0.0,
             0.0, 0.0, 0.0,
             0.0, 0.0, 0.5,
             0.0, 0.0, 0, 0],
    'Texture Transfer': [1.0, 1.0, 1.0,
                         1.0, 1.0, 1.0,
                         0.5, 0.0, 0.5,
                         0.0, 0.0, 0, 0],
}
DEFAULT_SCALE_CONFIG = 'Default'

ignore_style_list = ['headwear', 'accesories', 'shoes']

global device
global segmentor
global apply_openpose
global style_encoder
global model
global ddim_sampler
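# Converts a concatenated DeepFashion name of the form "fashion<SRC>___fashion<DST>"
# into the two relative image paths used by the dataset; e.g. the source half
# "fashionWOMENBlouses_Shirtsid0000035501_1front" maps to
# "WOMEN/Blouses_Shirts/id_00000355/01_1_front.jpg".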
def convert_fname(long_name):
    gender = 'MEN' if long_name[7:10] == 'MEN' else 'WOMEN'
    input_list = long_name.replace('fashion', '').split('___')
    # Regular expression pattern matching the relevant parts of each input string
    if gender == 'MEN':
        pattern = r'MEN(\w+)id(\d+)_(\d)(\w+)'
    else:
        pattern = r'WOMEN(\w+)id(\d+)_(\d)(\w+)'
    # Extract the matching substrings from each input string and format them as relative paths
    output_list = [f'{gender}/{category}/id_{id_num[:8]}/{id_num[8:]}_{view_num}_{view_desc}'
                   for (category, id_num, view_num, view_desc) in re.findall(pattern, ' '.join(input_list))]
    # Return the formatted paths with the image extension appended
    return [f + '.jpg' for f in output_list]
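# Loads a source image, target pose map, human mask, and per-garment style crops
# from the dataset roots (image_root, pose_root, mask_root, style_root) resolved
# in the __main__ block; only wired up when DEMO is False.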
def fetch_deepfashion(deepfashion_names):
    src_name, dst_name = convert_fname(deepfashion_names)

    input_image = np.array(Image.open(image_root/src_name))
    pose_image = np.array(Image.open(str(pose_root/dst_name)))
    mask_image = Image.open(str(mask_root/dst_name).replace('.jpg', '_mask.png'))

    temp = src_name.replace('.jpg', '').split('/')
    lastfolder = temp.pop(-1).replace('_', '/', 1)
    style_folder = style_root/('/'.join(temp + [lastfolder]))

    viscon_images = []
    for style_name in style_names:
        f_path = style_folder/f'{style_name}.jpg'
        if os.path.exists(str(f_path)):
            viscon_images.append(np.array(Image.open(f_path)))
        else:
            viscon_images.append(None)

    return [input_image, pose_image, mask_image, *viscon_images]
def select_gallery_image(evt: gr.SelectData):
    return evt.target.value[evt.index]['name']

def select_default_strength(strength_config):
    return SCALE_CONFIG[strength_config]

def change_all_scales(scale):
    return [float(scale)]*13
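# Encodes the per-garment crops with the style (image) encoder; empty slots are
# replaced by a black 224x224 placeholder so there is always one embedding per
# style name. Returns a float tensor on `device`, roughly (1, n_styles, emb_dim)
# depending on the encoder's output shape.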
def encode_style_images(style_images):
    style_embeddings = []
    for style_name, style_image in zip(style_names, style_images):
        if style_image is None:
            style_image = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
        #style_image = style_image.resize((224,224))
        style_image = style_encoder.preprocess(style_image).to(device)
        style_emb = style_encoder.postprocess(style_encoder(style_image)[0])
        style_embeddings.append(style_emb)

    styles = torch.tensor(np.array(style_embeddings)).squeeze(-2).unsqueeze(0).float().to(device)
    return styles
def save_viscon_images(*viscon_images):
    ret_images = []
    for image, name in zip(viscon_images, style_names):
        fname = str(VISCON_IMAGE_PATH/name) + '.jpg'
        if image:
            image = image.resize((224, 224))
            if os.path.exists(fname):
                os.remove(fname)
            image.save(fname)
        ret_images.append(image)
    return ret_images
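# Pads the input to a minimum aspect ratio, runs OpenPose (with hands) for the
# skeleton map, and builds a binary human mask from the ATR segmenter, optionally
# excluding the head and/or hair so those regions can be driven by text instead.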
def extract_pose_mask(input_image, detect_resolution,
                      ignore_head=True, ignore_hair=False):
    # skeleton
    input_image = pad_image(input_image, min_aspect_ratio=0.625)
    detected_map, _ = apply_openpose(resize_image(input_image, detect_resolution), hand=True)
    detected_map = HWC3(detected_map)

    # human mask
    cropped = segmentor(input_image, ignore_head=ignore_head, ignore_hair=ignore_hair)
    mask = cropped['human_mask']
    mask = Image.fromarray(np.array(mask*255, dtype=np.uint8), mode='L')

    return [detected_map, mask]
def extract_fashion(input_image):
    # style images
    cropped = segmentor(input_image)
    cropped_images = []
    for style_name in style_names:
        if style_name in cropped and style_name not in ignore_style_list:
            cropped_images.append(cropped[style_name])
        else:
            cropped_images.append(None)
    return [*cropped_images]
def get_image_files(image_path, ret_image=True, exts=['.jpg', '.jpeg', '.png']):
    images = []
    for ext in exts:
        images += [x for x in glob(str(Path(image_path)/f'*{ext}'))]
    if ret_image:
        images = [Image.open(x) for x in images]
    return images
def log_sample(seed, results, prompt, skeleton_image, mask_image, control_scales, *viscon_images):
    time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
    log_dir = LOG_PATH/time_str
    os.makedirs(str(log_dir), exist_ok=True)

    # save results
    concat = np.hstack((skeleton_image, *results))
    Image.fromarray(skeleton_image).save(str(log_dir/'skeleton.jpg'))
    Image.fromarray(mask_image).save(str(log_dir/'mask.png'))
    for i, result in enumerate(results):
        Image.fromarray(result).save(str(log_dir/f'result_{i}.jpg'))

    # save text
    with open(str(log_dir/'info.txt'), 'w') as f:
        f.write(f'prompt: {prompt} \n')
        f.write(f'seed: {seed}\n')
        control_str = [str(x) for x in control_scales]
        f.write(','.join(control_str) + '\n')

    # save viscon images
    for style_name, style_image in zip(style_names, viscon_images):
        if style_image is not None:
            style_image.save(str(log_dir/f'{style_name}.jpg'))
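# Main generation entry point. The OpenPose map goes in 'c_concat', the stacked
# fashion embeddings in 'c_crossattn', the text prompt in 'c_text', and the human
# mask in 'c_concat_mask'; an all-zero copy of the visual/mask conditions plus the
# negative prompt forms the unconditional branch for classifier-free guidance.
# Sliders c12..c0 set the per-block control strengths before DDIM sampling.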
def process(prompt, a_prompt, n_prompt, num_samples,
            ddim_steps, scale, seed, eta, mask_image, pose_image,
            c12, c11, c10, c9, c8, c7, c6, c5, c4, c3, c2, c1, c0,
            *viscon_images):

    with torch.no_grad():
        control_scales = [c12, c11, c10, c9, c8, c7, c6, c5, c4, c3, c2, c1, c0]
        mask = torch.tensor(mask_image.mean(-1)/255., dtype=torch.float)  # (512, 512), [0, 1]
        mask = mask.unsqueeze(0).to(device)  # (1, 512, 512)

        style_emb = encode_style_images(viscon_images)

        # fix me
        detected_map = HWC3(pose_image)
        #detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_NEAREST)
        H, W, C = detected_map.shape

        control = torch.from_numpy(detected_map.copy()).float().to(device) / 255.0
        control = torch.stack([control for _ in range(num_samples)], dim=0)
        control = einops.rearrange(control, 'b h w c -> b c h w').clone()

        if seed == -1:
            seed = random.randint(0, 65535)
        seed_everything(seed)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        new_style_shape = [num_samples] + [1] * (len(style_emb.shape)-1)
        cond = {"c_concat": [control],
                "c_crossattn": [style_emb.repeat(new_style_shape)],
                "c_text": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)],
                'c_concat_mask': [mask.repeat(num_samples, 1, 1, 1)]}
        un_cond = {"c_concat": [control],
                   "c_crossattn": [torch.zeros_like(style_emb).repeat(new_style_shape)],
                   "c_text": [model.get_learned_conditioning([n_prompt] * num_samples)],
                   'c_concat_mask': [torch.zeros_like(mask).repeat(num_samples, 1, 1, 1)]}
        shape = (4, H // 8, W // 8)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=True)

        model.control_scales = control_scales
        samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
                                         shape, cond, verbose=False, eta=eta,
                                         unconditional_guidance_scale=scale,
                                         unconditional_conditioning=un_cond)

        if config.save_memory:
            model.low_vram_shift(is_diffusing=False)

        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        results = [x_samples[i] for i in range(num_samples)]

        if LOG_SAMPLES:
            log_sample(seed, results, prompt, detected_map, mask_image, control_scales, *viscon_images)
        return results
def get_image(name, file_ext='.jpg'):
    fname = str(VISCON_IMAGE_PATH/name) + file_ext
    if not os.path.exists(fname):
        return None
    return Image.open(fname)

def get_image_numpy(name, file_ext='.png'):
    fname = str(VISCON_IMAGE_PATH/name) + file_ext
    if not os.path.exists(fname):
        return None
    return np.array(Image.open(fname))
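# Builds the Gradio UI: the left column extracts pose/mask and visual (fashion)
# conditions from a reference image and offers try-on galleries, while the right
# column holds the output gallery, text prompt, and sampling / control-strength
# settings. Event handlers at the bottom wire the widgets to the functions above.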
def create_app():
    block = gr.Blocks().queue()
    with block:
        with gr.Row():
            gr.Markdown("## ViscoNet: Visual ControlNet with Human Pose and Fashion <br> [Video tutorial](https://youtu.be/85NyIuLeV00)")
        with gr.Row():
            with gr.Column():
                with gr.Accordion("Get pose and mask", open=True):
                    with gr.Row():
                        input_image = gr.Image(source='upload', type="numpy", label='input image', value=np.array(get_image_numpy('ref')))
                        pose_image = gr.Image(source='upload', type="numpy", label='pose', value=np.array(get_image_numpy('pose')))
                        mask_image = gr.Image(source='upload', type="numpy", label='mask', value=np.array(get_image_numpy('mask')))
                    with gr.Accordion("Human Pose Samples", open=False):
                        with gr.Tab('Female'):
                            samples = get_image_files(str(SAMPLE_IMAGE_PATH/'pose/WOMEN/'))
                            female_pose_gallery = gr.Gallery(label='pose', show_label=False, value=samples).style(grid=3, height='auto')
                        with gr.Tab('Male'):
                            samples = get_image_files(str(SAMPLE_IMAGE_PATH/'pose/MEN/'))
                            male_pose_gallery = gr.Gallery(label='pose', show_label=False, value=samples).style(grid=3, height='auto')
                    with gr.Row():
                        #pad_checkbox = gr.Checkbox(label='Pad pose to square', value=True)
                        ignorehead_checkbox = gr.Checkbox(label='Ignore face in masking (for faceswap with text)', value=False)
                        ignorehair_checkbox = gr.Checkbox(label='Ignore hair in masking', value=False, visible=True)
                    with gr.Row():
                        #ignore_head_checkbox = gr.Checkbox(label='Ignore head', value=False)
                        get_pose_button = gr.Button(label="Get pose", value='Get pose')
                        get_fashion_button = gr.Button(label="Get visual", value='Get visual prompt')

                with gr.Accordion("Visual Conditions", open=True):
                    gr.Markdown('Drag-and-drop, or click from samples below.')
                    with gr.Column():
                        viscon_images = []
                        viscon_images_names2index = {}
                        viscon_len = len(style_names)
                        v_idx = 0
                        with gr.Row():
                            for _ in range(8):
                                viscon_name = style_names[v_idx]
                                vis = False if viscon_name in ignore_style_list else True
                                viscon_images.append(gr.Image(source='upload', type="pil", min_height=112, min_width=112, label=viscon_name, value=get_image(viscon_name), visible=vis))
                                viscon_images_names2index[viscon_name] = v_idx
                                v_idx += 1
                    viscon_button = gr.Button(value='Save as Default', visible=False if DEMO else True)

                viscon_galleries = []
with gr.Accordion("Virtual Try-on", open=False): | |
with gr.Column(): | |
#with gr.Accordion("Female", open=False): | |
with gr.Tab('Female'): | |
for garment, number in zip(['face', 'hair', 'top', 'bottom', 'outer'], [50, 150, 500, 500, 250]): | |
with gr.Tab(garment): | |
samples = [] | |
if WOMEN_GALLERY_PATH and os.path.exists(WOMEN_GALLERY_PATH): | |
samples = glob(os.path.join(WOMEN_GALLERY_PATH, f'**/{garment}.jpg'), recursive=True) | |
samples = random.choices(samples, k=number) | |
viscon_gallery = gr.Gallery(label='hair', allow_preview=False, show_label=False, value=samples).style(grid=4, height='auto') | |
viscon_galleries.append({'component':viscon_gallery, 'inputs':[garment]}) | |
#with gr.Accordion("Male", open=False): | |
with gr.Tab('Male'): | |
for garment, number in zip(['face','hair', 'top', 'bottom', 'outer'], [50, 150, 500, 500, 250]): | |
with gr.Tab(garment): | |
samples = [] | |
if MEN_GALLERY_PATH and os.path.exists(MEN_GALLERY_PATH): | |
samples = glob(os.path.join(MEN_GALLERY_PATH, f'**/{garment}.jpg'), recursive=True) | |
samples = random.choices(samples, k=number) | |
viscon_gallery = gr.Gallery(label='hair', allow_preview=False, show_label=False, value=samples).style(grid=4, height='auto') | |
viscon_galleries.append({'component':viscon_gallery, 'inputs':[garment]}) | |
            with gr.Column():
                result_gallery = gr.Gallery(label='Output', show_label=False, show_download_button=True, elem_id="gallery").style(grid=1, height='auto')
                with gr.Row():
                    max_samples = 8 if not DEMO else 4
                    num_samples = gr.Slider(label="Images", minimum=1, maximum=max_samples, value=1, step=1)
                    scale_all = gr.Slider(label='Control Strength', minimum=0, maximum=1, value=DEFAULT_CONTROL_SCALE, step=0.05)
                    seed = gr.Slider(label="Seed (-1 for random)", minimum=-1, maximum=2147483647, step=1, value=1561194236)  #randomize=True) #value=1561194234)
                if not DEMO:
                    DF_DEMO = 'fashionWOMENTees_Tanksid0000762403_1front___fashionWOMENTees_Tanksid0000762403_1front'
                    DF_EVAL = 'fashionWOMENBlouses_Shirtsid0000035501_1front___fashionWOMENBlouses_Shirtsid0000035501_1front'
                    DF_RESULT = "fashionWOMENTees_Tanksid0000796209_1front___fashionWOMENTees_Tanksid0000796209_2side"
                    deepfashion_names = gr.Textbox(label='Deepfashion name', value=DF_EVAL)
                gr.Markdown("The default config reconstructs an image faithful to the pose, mask and visual conditions. Reduce the control strength to tip the balance towards the text prompt for more creativity.")
                prompt = gr.Textbox(label="Text Prompt", value="")
                run_button = gr.Button(label="Run")
with gr.Accordion("Advanced options", open=False): | |
with gr.Accordion("Control Strength Scaling", open=False): | |
gr.Markdown("smaller value for stronger textual influence. c12 is highest spatial resolution controlling textures") | |
strength_select = gr.Dropdown(list(SCALE_CONFIG.keys()), label='strength settings', value=DEFAULT_SCALE_CONFIG) | |
scale_values = SCALE_CONFIG[DEFAULT_SCALE_CONFIG] | |
control_scales = [] | |
c_idx = 12 | |
with gr.Accordion("Advanced settings", open=False): | |
with gr.Row(): | |
for _ in range(3): | |
control_scales.append(gr.Slider(label=f'c{c_idx}', minimum=0, maximum=1, value=scale_values[12-c_idx], step=0.05)) | |
c_idx -= 1 | |
with gr.Row(): | |
for _ in range(3): | |
control_scales.append(gr.Slider(label=f'c{c_idx}', minimum=0, maximum=1, value=scale_values[12-c_idx], step=0.05)) | |
c_idx -= 1 | |
with gr.Row(): | |
for _ in range(3): | |
control_scales.append(gr.Slider(label=f'c{c_idx}', minimum=0, maximum=1, value=scale_values[12-c_idx], step=0.05)) | |
c_idx -= 1 | |
with gr.Row(): | |
for _ in range(4): | |
control_scales.append(gr.Slider(label=f'c{c_idx}', minimum=0, maximum=1, value=scale_values[12-c_idx], step=0.05)) | |
c_idx -= 1 | |
with gr.Row(): | |
detect_resolution = gr.Slider(label="OpenPose Resolution", minimum=128, maximum=768, value=512, step=1) | |
ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=50, value=20, step=1) | |
scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=12.0, step=0.1) | |
eta = gr.Number(label="eta (DDIM)", value=0.0, visible=False) | |
a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed') | |
n_prompt = gr.Textbox(label="Negative Prompt", | |
value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, sunglasses, hat') | |
        female_pose_gallery.select(fn=select_gallery_image, inputs=None, outputs=input_image)
        male_pose_gallery.select(fn=select_gallery_image, inputs=None, outputs=input_image)

        for vision_gallery in viscon_galleries:
            viscon_idx = viscon_images_names2index[vision_gallery['inputs'][0]]
            vision_gallery['component'].select(fn=select_gallery_image, inputs=None,
                                               outputs=viscon_images[viscon_idx])

        ips = [prompt, a_prompt, n_prompt, num_samples, ddim_steps, scale, seed, eta, mask_image, pose_image,
               *control_scales, *viscon_images]
        run_button.click(fn=process, inputs=ips, outputs=[result_gallery])
        prompt.submit(fn=process, inputs=ips, outputs=[result_gallery])
        get_pose_button.click(fn=extract_pose_mask, inputs=[input_image, detect_resolution,
                                                            ignorehead_checkbox, ignorehair_checkbox],
                              outputs=[pose_image, mask_image])
        get_fashion_button.click(fn=extract_fashion, inputs=input_image, outputs=[*viscon_images])
        viscon_button.click(fn=save_viscon_images, inputs=[*viscon_images], outputs=[*viscon_images])
        strength_select.select(fn=select_default_strength, inputs=[strength_select], outputs=[*control_scales])
        scale_all.release(fn=change_all_scales, inputs=[scale_all], outputs=[*control_scales])
        if not DEMO:
            deepfashion_names.submit(fn=fetch_deepfashion, inputs=[deepfashion_names], outputs=[input_image, pose_image, mask_image, *viscon_images])

    return block
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('--gpu', type=int, default=0, help='GPU id')
    parser.add_argument('--config', type=str, default='./configs/visconet_v1.yaml')
    parser.add_argument('--ckpt', type=str, default='./models/visconet_v1.pth')
    parser.add_argument('--public_link', action='store_true', default=False, help='Create public link')
    args = parser.parse_args()

    global device
    global segmentor
    global apply_openpose
    global style_encoder
    global model
    global ddim_sampler

    device = f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu'
    config_file = args.config
    model_ckpt = args.ckpt

    proj_config = OmegaConf.load(config_file)
    style_names = proj_config.dataset.train.params.style_names
    data_root = Path(proj_config.dataset.train.params.image_root)
    image_root = data_root/proj_config.dataset.train.params.image_dir
    style_root = data_root/proj_config.dataset.train.params.style_dir
    pose_root = data_root/proj_config.dataset.train.params.pose_dir
    mask_root = data_root/proj_config.dataset.train.params.mask_dir

    segmentor = SegmentCropper()
    apply_openpose = OpenposeDetector()

    if not os.path.exists(model_ckpt):
        snapshot_download(repo_id=HF_REPO, local_dir='./models',
                          allow_patterns=os.path.basename(model_ckpt))

    style_encoder = instantiate_from_config(proj_config.model.style_embedding_config).to(device)
    model = create_model(config_file).cpu()
    model.load_state_dict(load_state_dict(model_ckpt, location=device))
    model = model.to(device)
    model.cond_stage_model.device = device
    ddim_sampler = DDIMSampler(model)

    if not GALLERY_PATH.exists():
        zip_name = 'fashion.zip'
        snapshot_download(repo_id=HF_REPO, allow_patterns=zip_name, local_dir='.')
        from zipfile import ZipFile
        with ZipFile(zip_name, 'r') as zip_ref:
            zip_ref.extractall('.')
        os.remove(zip_name)

    # Build and launch the Gradio app
    block = create_app()
    block.launch(share=args.public_link)
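# Typical local launch (script name assumed; on a Hugging Face Space this file is
# executed with the defaults above):
#   python app.py --gpu 0 --config ./configs/visconet_v1.yaml --ckpt ./models/visconet_v1.pth
# Add --public_link to expose a shareable Gradio URL.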