DualStyleGAN / app.py
hysts's picture
hysts HF staff
Update
6af0602
raw
history blame
9.17 kB
#!/usr/bin/env python
from __future__ import annotations
import argparse
import functools
import os
import pathlib
import sys
from typing import Callable
# On Hugging Face Spaces, delete lines 10-17 in place from two vendored
# StyleGAN op modules before the model imports below run. NOTE(review):
# presumably those lines load the custom CUDA extension — confirm against the
# DualStyleGAN sources. The shell exit status is deliberately not checked.
if os.environ.get('SYSTEM') == 'spaces':
    for _patch_cmd in (
            "sed -i '10,17d' DualStyleGAN/model/stylegan/op/fused_act.py",
            "sed -i '10,17d' DualStyleGAN/model/stylegan/op/upfirdn2d.py",
    ):
        os.system(_patch_cmd)

# Put the DualStyleGAN checkout first on the import path so the `model.*`
# and `util` imports below resolve against it.
sys.path.insert(0, 'DualStyleGAN')
import dlib
import gradio as gr
import huggingface_hub
import numpy as np
import PIL.Image
import torch
import torch.nn as nn
import torchvision.transforms as T
from model.dualstylegan import DualStyleGAN
from model.encoder.align_all_parallel import align_face
from model.encoder.psp import pSp
from util import load_image, visualize
# Text shown in the Gradio UI.
ORIGINAL_REPO_URL = 'https://github.com/williamyang1991/DualStyleGAN'
TITLE = 'williamyang1991/DualStyleGAN'
DESCRIPTION = f"""A demo for {ORIGINAL_REPO_URL}
You can select style images for cartoon from the table below.
(Currently, style image tables for other style types are not available.)
The style image index should be in the following range:
(cartoon: 0-316, caricature: 0-198, anime: 0-173, arcane: 0-99, comic: 0-100, pixar: 0-121, slamdunk: 0-119)
"""
ARTICLE = '![cartoon style images](https://user-images.githubusercontent.com/18130694/159848447-96fa5194-32ec-42f0-945a-3b1958bf6e5e.jpg)'

# Hugging Face access token passed as `use_auth_token` to every
# hf_hub_download call. It is optional: when the TOKEN environment variable
# is unset this is None and huggingface_hub applies its own default
# (anonymous access or a cached login, depending on the library version).
# Previously os.environ['TOKEN'] made the module crash with KeyError at
# import time whenever TOKEN was not exported.
TOKEN = os.environ.get('TOKEN')
# Hub repository holding the encoder, generators and style codes.
MODEL_REPO = 'hysts/DualStyleGAN'
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Args:
        argv: Arguments to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before; pass an explicit list to
            configure the app programmatically or from tests.

    Returns:
        Namespace with ``device`` (default ``'cpu'``), ``theme``, ``live``,
        ``share``, ``port``, ``enable_queue`` (default ``True``),
        ``allow_flagging`` (default ``'never'``) and ``allow_screenshot``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--device', type=str, default='cpu')
    parser.add_argument('--theme', type=str)
    parser.add_argument('--live', action='store_true')
    parser.add_argument('--share', action='store_true')
    parser.add_argument('--port', type=int)
    # Queueing is on by default. The opt-out flag stores False into the
    # positively named `enable_queue` destination.
    parser.add_argument('--disable-queue',
                        dest='enable_queue',
                        action='store_false')
    parser.add_argument('--allow-flagging', type=str, default='never')
    parser.add_argument('--allow-screenshot', action='store_true')
    return parser.parse_args(argv)
def load_encoder(device: torch.device) -> nn.Module:
    """Download the pSp encoder checkpoint and build the encoder.

    The options dict stored in the checkpoint is reused, with ``device`` set
    to ``device.type`` and ``checkpoint_path`` set to the downloaded file
    (presumably so pSp loads its weights from there — verify). The model is
    moved to ``device`` and returned in eval mode.
    """
    checkpoint_file = huggingface_hub.hf_hub_download(
        MODEL_REPO, 'models/encoder.pt', use_auth_token=TOKEN)
    # torch.load unpickles the file: only point this at a trusted repository.
    checkpoint = torch.load(checkpoint_file, map_location='cpu')
    options = checkpoint['opts']
    options.update(device=device.type, checkpoint_path=checkpoint_file)
    encoder = pSp(argparse.Namespace(**options))
    return encoder.to(device).eval()
def load_generator(style_type: str, device: torch.device) -> nn.Module:
    """Build a DualStyleGAN generator with the weights for ``style_type``.

    The ``g_ema`` state dict is read from ``models/<style_type>/generator.pt``
    in ``MODEL_REPO``. The model is moved to ``device`` and returned in eval
    mode.
    """
    generator = DualStyleGAN(1024, 512, 8, 2, res_index=6)
    weights_file = huggingface_hub.hf_hub_download(
        MODEL_REPO,
        f'models/{style_type}/generator.pt',
        use_auth_token=TOKEN,
    )
    # torch.load unpickles the file: only point this at a trusted repository.
    state = torch.load(weights_file, map_location='cpu')
    generator.load_state_dict(state['g_ema'])
    return generator.to(device).eval()
def load_exstylecode(style_type: str) -> dict[str, np.ndarray]:
    """Download and return the extrinsic style codes for ``style_type``.

    ``cartoon``, ``caricature`` and ``anime`` use ``refined_exstyle_code.npy``.
    Every other style type uses ``exstyle_code.npy``. The result is the dict
    stored in the file. NOTE(review): keys are presumably style-image names,
    since `run` indexes them by position to pick a style.
    """
    refined_types = ('cartoon', 'caricature', 'anime')
    filename = ('refined_exstyle_code.npy'
                if style_type in refined_types else 'exstyle_code.npy')
    code_file = huggingface_hub.hf_hub_download(
        MODEL_REPO, f'models/{style_type}/{filename}', use_auth_token=TOKEN)
    # allow_pickle is required because the .npy holds a pickled object
    # (unwrapped with .item()); only load it from a trusted repository.
    return np.load(code_file, allow_pickle=True).item()
def create_transform() -> Callable:
    """Return the preprocessing pipeline for the aligned face image.

    Resize to 256, center-crop 256x256, convert to a tensor, then normalize
    every channel with mean 0.5 and std 0.5.
    """
    steps = [
        T.Resize(256),
        T.CenterCrop(256),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    return T.Compose(steps)
def create_dlib_landmark_model():
    """Download dlib's 68-point landmark file and wrap it in a shape predictor.

    The ``.dat`` file comes from the ``hysts/dlib_face_landmark_model`` repo.
    """
    model_file = huggingface_hub.hf_hub_download(
        'hysts/dlib_face_landmark_model',
        'shape_predictor_68_face_landmarks.dat',
        use_auth_token=TOKEN,
    )
    return dlib.shape_predictor(model_file)
def denormalize(tensor: torch.Tensor) -> torch.Tensor:
    """Map values from [-1, 1] to ``uint8`` pixels in [0, 255].

    Out-of-range values are clamped first. The final float-to-uint8 cast
    truncates rather than rounds (e.g. 0.0 -> 127).
    """
    scaled = tensor.add(1).div(2).mul(255)
    return scaled.clamp(0, 255).to(torch.uint8)
def postprocess(tensor: torch.Tensor) -> PIL.Image.Image:
    """Convert one channel-first (C, H, W) image tensor into a PIL image.

    Values are mapped to ``uint8`` by ``denormalize``. The array is copied to
    the CPU and its axes are reordered to (H, W, C) for PIL.
    """
    pixels = denormalize(tensor).cpu().numpy()
    return PIL.Image.fromarray(pixels.transpose(1, 2, 0))
@torch.inference_mode()
def run(
    image,
    style_type: str,
    style_id: float,
    dlib_landmark_model,
    encoder: nn.Module,
    generator_dict: dict[str, nn.Module],
    exstyle_dict: dict[str, dict[str, np.ndarray]],
    transform: Callable,
    device: torch.device,
) -> tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image,
           PIL.Image.Image]:
    """Align the face in ``image`` and stylize it with one DualStyleGAN style.

    Args:
        image: Uploaded file object. Only its ``.name`` is used, as the path
            handed to ``align_face``. NOTE(review): presumably the tempfile
            Gradio passes for ``Image(type='file')`` inputs — confirm.
        style_type: Key selecting the generator and style codes.
        style_id: Index of the style image. It is truncated with ``int`` and
            clamped into ``[0, len(exstyles) - 1]``.
        dlib_landmark_model: Landmark predictor passed to ``align_face``.
        encoder: Encoder returning ``(reconstruction, instyle)``.
        generator_dict: Generator per style type.
        exstyle_dict: Style codes per style type.
        transform: Preprocessing applied to the aligned face.
        device: Device that inference runs on.

    Returns:
        Five images: the aligned face, the encoder reconstruction, result 1
        (color and structure transfer), result 2 (structure transfer only)
        and result 3 (color-related layers deactivated).
    """
    generator = generator_dict[style_type]
    exstyles = exstyle_dict[style_type]

    # The UI supplies an arbitrary number; truncate and clamp it so it always
    # selects an existing style code.
    style_id = min(max(0, int(style_id)), len(exstyles) - 1)
    stylename = list(exstyles.keys())[style_id]

    image = align_face(filepath=image.name, predictor=dlib_landmark_model)
    input_data = transform(image).unsqueeze(0).to(device)

    img_rec, instyle = encoder(input_data,
                               randomize_noise=False,
                               return_latents=True,
                               z_plus_latent=True,
                               return_z_plus_latent=True,
                               resize=False)
    img_rec = torch.clamp(img_rec.detach(), -1, 1)

    # Two copies of the style code. latent[0] drives both color and structure
    # transfer. latent[1] takes layers 7-17 from the input's own code, so only
    # structure is transferred.
    latent = torch.tensor(exstyles[stylename]).repeat(2, 1, 1).to(device)
    latent[1, 7:18] = instyle[0, 7:18]
    # Flatten the two leading dims for the generator's style mapping, then
    # restore the original shape.
    exstyle = generator.generator.style(
        latent.reshape(latent.shape[0] * latent.shape[1],
                       latent.shape[2])).reshape(latent.shape)

    # One weight per layer (7 + 11 = 18). Both copies are generated in a
    # single batch.
    img_gen, _ = generator([instyle.repeat(2, 1, 1)],
                           exstyle,
                           z_plus_latent=True,
                           truncation=0.7,
                           truncation_latent=0,
                           use_res=True,
                           interp_weights=[0.6] * 7 + [1] * 11)
    img_gen = torch.clamp(img_gen.detach(), -1, 1)

    # Deactivate the color-related layers by setting w_c = 0 (the last 11
    # interpolation weights).
    img_gen2, _ = generator([instyle],
                            exstyle[0:1],
                            z_plus_latent=True,
                            truncation=0.7,
                            truncation_latent=0,
                            use_res=True,
                            interp_weights=[0.6] * 7 + [0] * 11)
    img_gen2 = torch.clamp(img_gen2.detach(), -1, 1)

    return (
        image,
        postprocess(img_rec[0]),
        postprocess(img_gen[0]),
        postprocess(img_gen[1]),
        postprocess(img_gen2[0]),
    )
def main():
    """Load every model once, then build and launch the Gradio interface.

    The generators and style codes for all style types are loaded up front,
    so a request only runs inference.
    """
    gr.close_all()

    args = parse_args()
    device = torch.device(args.device)

    style_types = [
        'cartoon',
        'caricature',
        'anime',
        'arcane',
        'comic',
        'pixar',
        'slamdunk',
    ]
    generator_dict = {name: load_generator(name, device) for name in style_types}
    exstyle_dict = {name: load_exstylecode(name) for name in style_types}
    dlib_landmark_model = create_dlib_landmark_model()
    encoder = load_encoder(device)
    transform = create_transform()

    # Bind the request-independent arguments so the UI only supplies
    # (image, style_type, style_id). update_wrapper copies run's metadata
    # (__name__, __doc__, ...) onto the partial, which has none of its own.
    func = functools.update_wrapper(
        functools.partial(
            run,
            dlib_landmark_model=dlib_landmark_model,
            encoder=encoder,
            generator_dict=generator_dict,
            exstyle_dict=exstyle_dict,
            transform=transform,
            device=device,
        ),
        run,
    )

    default_style_id = 26
    examples = [[path.as_posix(), 'cartoon', default_style_id]
                for path in sorted(pathlib.Path('images').glob('*.jpg'))]

    inputs = [
        gr.inputs.Image(type='file', label='Input Image'),
        gr.inputs.Radio(style_types,
                        type='value',
                        default='cartoon',
                        label='Style Type'),
        gr.inputs.Number(default=default_style_id, label='Style Image Index'),
    ]
    # Order must match the 5-tuple returned by run.
    outputs = [
        gr.outputs.Image(type='pil', label='Aligned Face'),
        gr.outputs.Image(type='pil', label='Reconstructed'),
        gr.outputs.Image(type='pil',
                         label='Result 1 (Color and structure transfer)'),
        gr.outputs.Image(type='pil',
                         label='Result 2 (Structure transfer only)'),
        gr.outputs.Image(type='pil',
                         label='Result 3 (Color-related layers deactivated)'),
    ]

    demo = gr.Interface(
        func,
        inputs,
        outputs,
        examples=examples,
        theme=args.theme,
        title=TITLE,
        description=DESCRIPTION,
        article=ARTICLE,
        allow_screenshot=args.allow_screenshot,
        allow_flagging=args.allow_flagging,
        live=args.live,
    )
    demo.launch(
        enable_queue=args.enable_queue,
        server_port=args.port,
        share=args.share,
    )
# Launch the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()