|
import os |
|
import re |
|
import gradio as gr |
|
import torch |
|
import torch.nn.functional as F |
|
from torch.optim import Adam |
|
from torchvision.transforms import transforms as T |
|
import clip |
|
from tr0n.config import parse_args |
|
from tr0n.modules.models.model_stylegan import Model |
|
from tr0n.modules.models.loss import AugCosineSimLatent |
|
from tr0n.modules.optimizers.sgld import SGLD |
|
from bad_words import bad_words |
|
|
|
# Prefer the GPU whenever PyTorch can see one; everything below is placed on `device`.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Pretrained TR0N translator checkpoints, one per conditioning mode. The
# "text" translator drives the text-to-face tab, the "image" translator the
# face-to-face interpolation tab.
model_modes = dict(
    text=dict(
        checkpoint="https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-text.pth",
    ),
    image=dict(
        checkpoint="https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-image.pth",
    ),
)

# Disable HuggingFace tokenizers' internal parallelism for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
|
|
|
|
|
|
# Start from the project's demo defaults, then pin the settings this demo's
# checkpoints were trained with. Writes through `config_vars` update `config`
# itself (assumes parse_args returns an argparse.Namespace-like object whose
# vars() is its live __dict__ -- TODO confirm).
config = parse_args(is_demo=True)
config_vars = vars(config)
config_vars.update(
    stylegan_gen="sg2-ffhq-1024",
    with_gmm=True,
    num_mixtures=10,
)

# Build the TR0N model (the third constructor argument is passed as None; its
# meaning is defined in tr0n.modules.models.model_stylegan), move it to the
# selected device and switch it to inference mode.
model = Model(config, device, None)
model.to(device)
model.eval()

# The translator is never trained here: each request only optimizes the
# latent tensors it creates, so freeze every translator parameter.
model.translator.requires_grad_(False)

# Objective comparing a target CLIP latent with the predicted one.
loss = AugCosineSimLatent()
|
|
|
|
|
# Channel mean / std of OpenAI CLIP's image preprocessing; the transformed
# faces are fed to `model.clip.encode_image`.
_CLIP_IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
_CLIP_IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)

# PIL image -> normalized CLIP input tensor: bicubic resize of the shorter
# side to 224, 224x224 centre crop, to-tensor, per-channel normalization.
transforms_image = T.Compose(
    [
        T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(_CLIP_IMAGE_MEAN, _CLIP_IMAGE_STD),
    ]
)
|
|
|
|
|
# Fetch both translator checkpoints onto the CPU (torch.hub caches the
# downloads), pull out each translator's weights, and start in text mode.
checkpoint_text, checkpoint_image = (
    torch.hub.load_state_dict_from_url(model_modes[mode]["checkpoint"], map_location="cpu")
    for mode in ("text", "image")
)
translator_state_dict_text = checkpoint_text['translator_state_dict']
translator_state_dict_image = checkpoint_image['translator_state_dict']

model.translator.load_state_dict(translator_state_dict_text)
|
|
|
|
|
# Custom stylesheet passed to gr.Blocks(css=css):
# - every link is rendered inline-block, black and without underline;
# - elements with id "image-gen" (the output image components) are fixed at
#   256x256 and centred horizontally.
css = """
a {
    display: inline-block;
    color: black !important;
    text-decoration: none !important;
}
#image-gen {
    height: 256px;
    width: 256px;
    margin-left: auto;
    margin-right: auto;
}
"""
|
|
|
|
|
def _slerp(val, low, high, eps=1e-7):
    """Spherically interpolate, row by row, between two batches of vectors.

    The angle between each pair of rows is measured on their unit-normalized
    versions; the interpolation weights are then applied to the original
    (unnormalized) rows, so ``val == 0`` returns ``low`` and ``val == 1``
    returns ``high``.

    Args:
        val: Interpolation coefficient (0 -> ``low``, 1 -> ``high``).
        low: Tensor of shape (batch, dim).
        high: Tensor of shape (batch, dim).
        eps: Threshold on ``|sin(omega)|`` below which a pair is treated as
            (anti)parallel and linearly interpolated instead.

    Returns:
        Tensor of shape (batch, dim) with the interpolated rows.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)

    # Rounding can push the cosine of (near-)parallel unit vectors just past
    # +/-1, where acos returns NaN.
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    # For (anti)parallel rows sin(omega) == 0 and the slerp weights become
    # 0/0 (NaN). Fall back to plain linear interpolation there -- the limit of
    # slerp as omega -> 0 (e.g. the same image supplied twice).
    degenerate = so.abs() < eps
    ones = torch.ones_like(so)
    safe_so = torch.where(degenerate, ones, so)
    weight_low = torch.where(degenerate, (1.0 - val) * ones, torch.sin((1.0 - val) * omega) / safe_so)
    weight_high = torch.where(degenerate, val * ones, torch.sin(val * omega) / safe_so)

    return weight_low.unsqueeze(1) * low + weight_high.unsqueeze(1) * high
|
|
|
|
|
def model_mode_text_select():
    """Put the text-conditioned translator weights into the shared model.

    Bound to the text-to-face tab's ``select`` event. Note that this mutates
    the single global ``model`` used by every session.
    """
    translator = model.translator
    translator.load_state_dict(translator_state_dict_text)
|
|
|
|
|
def model_mode_image_select():
    """Put the image-conditioned translator weights into the shared model.

    Bound to the face-to-face tab's ``select`` event. Note that this mutates
    the single global ``model`` used by every session.
    """
    translator = model.translator
    translator.load_state_dict(translator_state_dict_image)
|
|
|
|
|
def text_to_face_generate(text):
    """Generate a face image matching a text prompt.

    The text-conditioned translator maps the tokenized prompt to a target CLIP
    latent plus mixture logits and per-component means over StyleGAN latents
    (``with_gmm`` / ``num_mixtures`` in the config). The mixture is then
    refined for 100 steps -- SGLD on the means, Adam on the logits -- to
    minimize ``loss`` between the target latent and the CLIP latent the model
    predicts for the generated image (``times_augment_pred_image=50``).

    Args:
        text: The user's prompt.

    Returns:
        The generated face as a ``PIL.Image.Image``.

    Raises:
        gr.Error: If the prompt is empty/blank or contains a blocked word.
    """
    if not text or not text.strip():
        raise gr.Error("You need to provide a prompt.")

    # Literal (re.escape), case-insensitive whole-word match: capitalised
    # variants are caught too, and a list entry containing regex
    # metacharacters can neither crash the handler nor match unintended text.
    for word in bad_words:
        if re.search(rf"\b{re.escape(word)}\b", text, flags=re.IGNORECASE):
            raise gr.Error("Unsafe content found. Please try again with a different prompt.")

    # model.translator is shared by every session and otherwise only swapped by
    # the tabs' `select` handlers (no select fires on initial page load, and
    # another user may have switched tabs). Load the text weights explicitly
    # so this request never runs with the image translator.
    model.translator.load_state_dict(translator_state_dict_text)

    text_tok = clip.tokenize([text], truncate=True).to(device)

    with torch.no_grad():
        target_clip_latent, w_mixture_logits, w_means = model(
            x=text_tok, x_type='text', return_after_translator=True, no_sample=True
        )
        # Repeat each component's logit across the latent dimension so it can
        # weight the matching component mean (assumes logits are
        # (batch, num_mixtures) -- TODO confirm).
        pi = w_mixture_logits.unsqueeze(-1).repeat(1, 1, w_means.shape[-1])
        w = w_means

    # Both tensors were produced under no_grad, so they are leaf tensors that
    # can be optimized directly.
    w.requires_grad = True
    pi.requires_grad = True

    optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)
    optimizer_pi = Adam((pi,), lr=5e-3)

    for _ in range(100):
        # Collapse the mixture into one latent: softmax-weighted sum over dim 1.
        soft_pi = F.softmax(pi, dim=1)
        w_prime = (soft_pi * w).sum(dim=1)

        _, _, pred_clip_latent, _, _ = model(x=w_prime, x_type='gan_latent', times_augment_pred_image=50)

        loss_value = loss(target_clip_latent, pred_clip_latent)
        loss_value.backward()
        torch.nn.utils.clip_grad_norm_((w,), 1.)
        torch.nn.utils.clip_grad_norm_((pi,), 1.)
        optimizer_w.step()
        optimizer_pi.step()
        optimizer_w.zero_grad()
        optimizer_pi.zero_grad()

    with torch.no_grad():
        soft_pi = F.softmax(pi, dim=1)
        w_prime = (soft_pi * w).sum(dim=1)
        _, _, _, _, pred_image_raw = model(x=w_prime, x_type='gan_latent')

    # Rescale the generator output (assumed to be in [-1, 1]) to [0, 1]. Clamp
    # because the generator can overshoot, and ToPILImage converts float
    # tensors with mul(255).byte(), which wraps out-of-range values into
    # visible artefacts.
    pred_image = ((pred_image_raw[0] + 1.) / 2.).clamp(0., 1.).cpu()
    return T.ToPILImage()(pred_image)
|
|
|
|
|
def face_to_face_interpolate(image1, image2, interp_lambda=0.5):
    """Generate a face whose CLIP embedding interpolates between two faces.

    Both images are embedded with CLIP and their embeddings are slerped with
    coefficient ``interp_lambda``. The image-conditioned translator proposes a
    StyleGAN latent for the interpolated embedding. That latent is then refined
    for 100 SGLD steps to minimize ``loss`` against the interpolated
    embedding.

    Args:
        image1: First face as a PIL image (the target when ``interp_lambda``
            is 0).
        image2: Second face as a PIL image (the target when ``interp_lambda``
            is 1).
        interp_lambda: Interpolation coefficient in [0, 1].

    Returns:
        The generated face as a ``PIL.Image.Image``.

    Raises:
        gr.Error: If either image is missing.
    """
    if image1 is None or image2 is None:
        raise gr.Error("You need to provide two images as input.")

    # model.translator is shared by every session and otherwise only swapped by
    # the tabs' `select` handlers (examples may also be computed before any
    # select event). Load the image weights explicitly so this request never
    # runs with the text translator.
    model.translator.load_state_dict(translator_state_dict_image)

    # transforms_image normalizes exactly 3 channels, so force RGB in case an
    # RGBA or greyscale image reaches this function.
    image1_pt = transforms_image(image1.convert("RGB")).to(device)
    image2_pt = transforms_image(image2.convert("RGB")).to(device)

    with torch.no_grad():
        images_pt = torch.stack([image1_pt, image2_pt])
        target_clip_latents = model.clip.encode_image(images_pt).detach().float()
        target_clip_latent = _slerp(
            float(interp_lambda),
            target_clip_latents[0].unsqueeze(0),
            target_clip_latents[1].unsqueeze(0),
        )
        _, _, w = model(x=target_clip_latent, x_type='clip_latent', return_after_translator=True)

    # Produced under no_grad, so `w` is a leaf tensor we can optimize directly.
    w.requires_grad = True

    optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)

    for _ in range(100):
        _, _, pred_clip_latent, _, _ = model(x=w, x_type='gan_latent', times_augment_pred_image=50)

        loss_value = loss(target_clip_latent, pred_clip_latent)
        loss_value.backward()
        torch.nn.utils.clip_grad_norm_((w,), 1.)
        optimizer_w.step()
        optimizer_w.zero_grad()

    with torch.no_grad():
        _, _, _, _, pred_image_raw = model(x=w, x_type='gan_latent')

    # Rescale the generator output (assumed to be in [-1, 1]) to [0, 1]. Clamp
    # because the generator can overshoot, and ToPILImage converts float
    # tensors with mul(255).byte(), which wraps out-of-range values into
    # visible artefacts.
    pred_image = ((pred_image_raw[0] + 1.) / 2.).clamp(0., 1.).cpu()
    return T.ToPILImage()(pred_image)
|
|
|
|
|
# One-click example prompts for the text-to-face tab.
examples_text = [
    "Muhammad Ali",
    "Tinker Bell",
    "A man with glasses, long black hair with sideburns and a goatee",
    "A child with blue eyes and straight brown hair in the sunshine",
    "A hairdresser",
    "A young boy with glasses and an angry face",
    "Denzel Washington",
    "A portrait of Angela Merkel",
    "President Emmanuel Macron",
    "Prime Minister Shinzo Abe",
]

# Example image pairs for the interpolation tab:
# [./examples/example_<k>_1.jpg, ./examples/example_<k>_2.jpg] for k = 1..4.
examples_image = [
    [f"./examples/example_{pair}_{side}.jpg" for side in (1, 2)]
    for pair in range(1, 5)
]
|
|
|
|
|
# Gradio UI: a header (title, attribution, paper/code badges, description)
# followed by two tabs, then the event wiring. Event listeners are registered
# inside the Blocks context.
with gr.Blocks(css=css) as demo:
    gr.Markdown("<h1><center>TR0N Face Generation Demo</center></h1>")
    gr.Markdown("<h3><center><a href='https://layer6.ai/'>by Layer 6 AI</a></center></h3>")
    gr.Markdown("""<p align='middle'>
<a href='https://arxiv.org/abs/2304.13742'><img src='https://img.shields.io/badge/arXiv-2304.13742-b31b1b.svg' /></a>
<a href='https://github.com/layer6ai-labs/tr0n'><img src='https://badgen.net/badge/icon/github?icon=github&label' /></a>
</p>""")
    gr.Markdown("We introduce TR0N, a simple and efficient method to add any type of conditioning to pre-trained generative models. For this demo, we add two types of conditioning to a StyleGAN2 model pre-trained on images of human faces. First, we add text-conditioning to turn StyleGAN2 into a text-to-face model. Second, we add image semantic conditioning to StyleGAN2 to enable face-to-face interpolation. For more details and results on many other generative models, please refer to our paper linked above.")

    # Tab 1: prompt -> generated face (handled by text_to_face_generate).
    with gr.Tab("Text-to-face generation") as text_to_face_generation_demo:
        text_to_face_generation_input = gr.Textbox(label="Enter your prompt", placeholder="e.g. A man with a beard and glasses", max_lines=1)
        text_to_face_generation_button = gr.Button("Generate")
        # NOTE(review): both output images share elem_id "image-gen" (sized by
        # the css `#image-gen` rule); duplicate HTML ids are technically invalid.
        text_to_face_generation_output = gr.Image(label="Generated image", elem_id="image-gen")
        text_to_face_generation_examples = gr.Examples(examples=examples_text, fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)

    # Tab 2: two faces + coefficient -> interpolated face (handled by
    # face_to_face_interpolate).
    with gr.Tab("Face-to-face interpolation") as face_to_face_interpolation_demo:
        gr.Markdown("We note that interpolations are not expected to recover the given images, even when the coefficient is 0 or 1.")
        # The two input images side by side.
        with gr.Row():
            face_to_face_interpolation_input1 = gr.Image(label="Image 1", type="pil")
            face_to_face_interpolation_input2 = gr.Image(label="Image 2", type="pil")
        face_to_face_interpolation_lambda = gr.Slider(label="Interpolation coefficient", minimum=0, maximum=1, value=0.5, step=0.01)
        face_to_face_interpolation_button = gr.Button("Interpolate")
        face_to_face_interpolation_output = gr.Image(label="Interpolated image", elem_id="image-gen")
        face_to_face_interpolation_examples = gr.Examples(examples=examples_image, fn=face_to_face_interpolate, inputs=[face_to_face_interpolation_input1, face_to_face_interpolation_input2, face_to_face_interpolation_lambda], outputs=face_to_face_interpolation_output)

    # Selecting a tab swaps the translator weights of the single global model.
    # NOTE(review): that model is shared by all sessions and no select event
    # fires on initial page load, so the handlers cannot by themselves
    # guarantee that a request runs with its tab's translator.
    text_to_face_generation_demo.select(fn=model_mode_text_select)
    # Pressing Enter in the textbox and clicking "Generate" trigger the same handler.
    text_to_face_generation_input.submit(fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
    text_to_face_generation_button.click(fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)

    face_to_face_interpolation_demo.select(fn=model_mode_image_select)
    face_to_face_interpolation_button.click(fn=face_to_face_interpolate, inputs=[face_to_face_interpolation_input1, face_to_face_interpolation_input2, face_to_face_interpolation_lambda], outputs=face_to_face_interpolation_output)
|
|
|
|
|
# Turn on Gradio's request queue (queue() returns the same Blocks object, so
# the call chains) and start serving the app.
demo.queue().launch()
|
|