File size: 9,780 Bytes
cc9204b 88b44f3 cc9204b 88b44f3 cc9204b d940272 cc9204b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 |
import os
import re
import gradio as gr
import torch
import torch.nn.functional as F
from torch.optim import Adam
from torchvision.transforms import transforms as T
import clip
from tr0n.config import parse_args
from tr0n.modules.models.model_stylegan import Model
from tr0n.modules.models.loss import AugCosineSimLatent
from tr0n.modules.optimizers.sgld import SGLD
from bad_words import bad_words
# Run on the GPU when one is available; models and inputs are moved here.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Remote TR0N translator checkpoints, one per conditioning mode.
model_modes = {
    "text": {
        "checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-text.pth",
    },
    "image": {
        "checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-image.pth",
    }
}
# Disable HuggingFace `tokenizers` parallelism (presumably to avoid its
# fork/parallelism warnings in the server process — confirm if still needed).
os.environ['TOKENIZERS_PARALLELISM'] = "false"
# Set config params.  vars() returns the object's own __dict__ (not a copy),
# so the assignments below mutate `config` itself before Model() reads it.
config = parse_args(is_demo=True)
config_vars = vars(config)
config_vars["stylegan_gen"] = "sg2-ffhq-1024"
config_vars["with_gmm"] = True
config_vars["num_mixtures"] = 10
# NOTE(review): third Model argument is passed as None — presumably an
# optional component not needed for the demo; confirm against Model.__init__.
model = Model(config, device, None)
model.to(device)
model.eval()
# Freeze the translator: at inference time only the GAN latents (and, for
# text, the mixture logits) are optimized, never the translator weights.
for p in model.translator.parameters():
    p.requires_grad = False
# Loss between a target CLIP latent and the CLIP latent of a generated image.
loss = AugCosineSimLatent()
# Preprocessing for images fed to model.clip.encode_image: bicubic resize,
# 224x224 centre crop, tensor conversion, and normalization with the
# OpenAI CLIP RGB mean/std constants.
transforms_image = T.Compose([
    T.Resize(224, interpolation=T.InterpolationMode.BICUBIC),
    T.CenterCrop(224),
    T.ToTensor(),
    T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])
# Download (cached by torch.hub) both checkpoints once at startup onto the
# CPU and keep only their translator weights for later switching.
checkpoint_text = torch.hub.load_state_dict_from_url(model_modes["text"]["checkpoint"], map_location="cpu")
translator_state_dict_text = checkpoint_text['translator_state_dict']
checkpoint_image = torch.hub.load_state_dict_from_url(model_modes["image"]["checkpoint"], map_location="cpu")
translator_state_dict_image = checkpoint_image['translator_state_dict']
# Default to the text translator, matching the initially shown
# "Text-to-face generation" tab.
model.translator.load_state_dict(translator_state_dict_text)
# Page stylesheet passed to gr.Blocks: black, undecorated links (the header
# badges/links) and a fixed-size, horizontally centred box for components
# created with elem_id="image-gen" (the two output images).
css = """
a {
display: inline-block;
color: black !important;
text-decoration: none !important;
}
#image-gen {
height: 256px;
width: 256px;
margin-left: auto;
margin-right: auto;
}
"""
def _slerp(val, low, high, eps=1e-7):
    """Spherically interpolate between two batches of vectors.

    The angle is computed from the L2-normalized rows of ``low`` and
    ``high``, but the interpolation weights are applied to the original
    (unnormalized) vectors, so ``val == 0`` returns ``low`` and
    ``val == 1`` returns ``high``.

    Args:
        val: Interpolation coefficient (scalar), typically in [0, 1].
        low: Tensor of shape (batch, dim).
        high: Tensor of the same shape as ``low``.
        eps: When ``|sin(omega)|`` falls below this, the rows are treated as
            (anti)parallel and plain linear interpolation is used instead.

    Returns:
        Tensor of the same shape as ``low``.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)
    # Rounding can push the cosine of nearly identical vectors just past
    # +/-1, where acos returns NaN, so clamp it into acos's domain.
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)
    # sin(omega) == 0 for (anti)parallel inputs, e.g. the same image given
    # twice; the slerp weights would then be 0/0 = NaN.  Use the lerp
    # weights there (the limit of slerp as omega -> 0) and divide by a safe
    # denominator so no NaN is ever produced.
    degenerate = so.abs() < eps
    safe_so = torch.where(degenerate, torch.ones_like(so), so)
    w_low = torch.sin((1.0 - val) * omega) / safe_so
    w_high = torch.sin(val * omega) / safe_so
    w_low = torch.where(degenerate, torch.full_like(w_low, 1.0 - val), w_low)
    w_high = torch.where(degenerate, torch.full_like(w_high, val), w_high)
    return w_low.unsqueeze(1) * low + w_high.unsqueeze(1) * high
def model_mode_text_select():
    """Load the text-conditioning translator weights into the shared model.

    Used as the handler for the text-to-face tab's ``select`` event.
    """
    text_weights = translator_state_dict_text
    model.translator.load_state_dict(text_weights, strict=True)
def model_mode_image_select():
    """Load the image-conditioning translator weights into the shared model.

    Used as the handler for the face-to-face tab's ``select`` event.
    """
    image_weights = translator_state_dict_image
    model.translator.load_state_dict(image_weights, strict=True)
def text_to_face_generate(text):
    """Generate a face image that matches a text prompt.

    The TR0N translator maps the CLIP tokens of ``text`` to a target CLIP
    latent plus a mixture of StyleGAN latents (mixture logits and per-mixture
    means).  The means ``w`` and the logits ``pi`` are refined for one step
    (SGLD on ``w``, Adam on ``pi``) against ``loss`` between the target CLIP
    latent and the CLIP latent predicted for the generated image.  Then the
    softmax(pi)-weighted sum of the refined means is rendered.

    Args:
        text: Prompt entered by the user.

    Returns:
        The generated image as a PIL image.

    Raises:
        gr.Error: If the prompt is empty or whitespace-only, or contains a
            word listed in ``bad_words``.
    """
    if not text or not text.strip():
        raise gr.Error("You need to provide a prompt.")
    # Whole-word filter: re.escape makes each entry match literally, and
    # IGNORECASE stops capitalised variants from bypassing the check.
    for word in bad_words:
        if re.search(rf"\b{re.escape(word)}\b", text, flags=re.IGNORECASE):
            raise gr.Error("Unsafe content found. Please try again with a different prompt.")

    # The translator is shared by all sessions, and the tab `select` handlers
    # swap its weights.  A request not preceded by a tab switch (e.g. a fresh
    # page load on the default tab while another user is on the
    # interpolation tab) would otherwise run with the image translator.
    model.translator.load_state_dict(translator_state_dict_text)

    text_tok = clip.tokenize([text], truncate=True).to(device)
    # initialize optimization from the translator's output
    with torch.no_grad():
        target_clip_latent, w_mixture_logits, w_means = model(
            x=text_tok, x_type='text', return_after_translator=True, no_sample=True
        )
    # Repeat each mixture logit along the latent dimension so pi matches w's
    # shape; the softmax below is taken over the mixture axis (dim=1).
    pi = w_mixture_logits.unsqueeze(-1).repeat(1, 1, w_means.shape[-1])  # 1 x num_mixtures x w_dim
    w = w_means  # 1 x num_mixtures x w_dim
    w.requires_grad = True
    pi.requires_grad = True
    optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)
    optimizer_pi = Adam((pi,), lr=5e-3)

    # optimization (a single refinement step)
    for _ in range(1):
        w_prime = (F.softmax(pi, dim=1) * w).sum(dim=1)
        _, _, pred_clip_latent, _, _ = model(x=w_prime, x_type='gan_latent', times_augment_pred_image=50)
        l = loss(target_clip_latent, pred_clip_latent)
        l.backward()
        torch.nn.utils.clip_grad_norm_((w,), 1.)
        torch.nn.utils.clip_grad_norm_((pi,), 1.)
        optimizer_w.step()
        optimizer_pi.step()
        optimizer_w.zero_grad()
        optimizer_pi.zero_grad()

    # generate final image
    with torch.no_grad():
        w_prime = (F.softmax(pi, dim=1) * w).sum(dim=1)
        _, _, _, _, pred_image_raw = model(x=w_prime, x_type='gan_latent')
    # (x + 1) / 2 assumes a nominal [-1, 1] generator range.  Clamp because
    # ToPILImage converts float tensors via mul(255).byte(), which does not
    # saturate, so out-of-range pixels would wrap into visible artifacts.
    pred_image = ((pred_image_raw[0] + 1.) / 2.).clamp(0., 1.).cpu()
    return T.ToPILImage()(pred_image)
def face_to_face_interpolate(image1, image2, interp_lambda=0.5):
    """Generate a face whose CLIP embedding interpolates two input images.

    Both images are CLIP-encoded, and their embeddings are spherically
    interpolated with ``interp_lambda`` (0 -> ``image1``, 1 -> ``image2``).
    The image translator maps the result to a StyleGAN latent, which is
    refined for one SGLD step against ``loss`` before rendering.

    Args:
        image1: First input image (PIL, from the Gradio component).
        image2: Second input image (PIL, from the Gradio component).
        interp_lambda: Interpolation coefficient, in [0, 1] from the slider.

    Returns:
        The generated image as a PIL image.

    Raises:
        gr.Error: If either image is missing.
    """
    if image1 is None or image2 is None:
        raise gr.Error("You need to provide two images as input.")

    # The translator is shared by all sessions, and the tab `select` handlers
    # swap its weights, so another user's tab switch could leave the text
    # translator loaded.  Always use the image weights for this request.
    model.translator.load_state_dict(translator_state_dict_image)

    image1_pt = transforms_image(image1).to(device)
    image2_pt = transforms_image(image2).to(device)
    # initialize optimization from the translator's output
    with torch.no_grad():
        images_pt = torch.stack([image1_pt, image2_pt])
        target_clip_latents = model.clip.encode_image(images_pt).detach().float()
        target_clip_latent = _slerp(
            interp_lambda,
            target_clip_latents[0].unsqueeze(0),
            target_clip_latents[1].unsqueeze(0),
        )
        _, _, w = model(x=target_clip_latent, x_type='clip_latent', return_after_translator=True)
    w.requires_grad = True
    optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device)

    # optimization (a single refinement step)
    for _ in range(1):
        _, _, pred_clip_latent, _, _ = model(x=w, x_type='gan_latent', times_augment_pred_image=50)
        l = loss(target_clip_latent, pred_clip_latent)
        l.backward()
        torch.nn.utils.clip_grad_norm_((w,), 1.)
        optimizer_w.step()
        optimizer_w.zero_grad()

    # generate final image
    with torch.no_grad():
        _, _, _, _, pred_image_raw = model(x=w, x_type='gan_latent')
    # (x + 1) / 2 assumes a nominal [-1, 1] generator range.  Clamp because
    # ToPILImage converts float tensors via mul(255).byte(), which does not
    # saturate, so out-of-range pixels would wrap into visible artifacts.
    pred_image = ((pred_image_raw[0] + 1.) / 2.).clamp(0., 1.).cpu()
    return T.ToPILImage()(pred_image)
# Clickable example prompts for the text-to-face tab.
examples_text = [
    "Muhammad Ali",
    "Tinker Bell",
    "A man with glasses, long black hair with sideburns and a goatee",
    "A child with blue eyes and straight brown hair in the sunshine",
    "A hairdresser",
    "A young boy with glasses and an angry face",
    "Denzel Washington",
    "A portrait of Angela Merkel",
    "President Emmanuel Macron",
    "President Xi Jinping",
]
# Clickable example image pairs for the face-to-face tab; each entry is a
# [image 1, image 2] list of file paths.
examples_image = [
    list(image_pair)
    for image_pair in (
        ("./examples/example_1_1.jpg", "./examples/example_1_2.jpg"),
        ("./examples/example_2_1.jpg", "./examples/example_2_2.jpg"),
        ("./examples/example_3_1.jpg", "./examples/example_3_2.jpg"),
        ("./examples/example_4_1.jpg", "./examples/example_4_2.jpg"),
    )
]
# Build the Gradio UI: a header, then one tab per conditioning mode.
# Components render in the order they are created inside each context.
with gr.Blocks(css=css) as demo:
    gr.Markdown("<h1><center>TR0N Face Generation Demo</center></h1>")
    gr.Markdown("<h3><center><a href='https://layer6.ai/'>by Layer 6 AI</a></center></h3>")
    gr.Markdown("""<p align='middle'>
<a href='https://arxiv.org/abs/2304.13742'><img src='https://img.shields.io/badge/arXiv-2304.13742-b31b1b.svg' /></a>
<a href='https://github.com/layer6ai-labs/tr0n'><img src='https://badgen.net/badge/icon/github?icon=github&label' /></a>
</p>""")
    gr.Markdown("We introduce TR0N, a simple and efficient method to add any type of conditioning to pre-trained generative models. For this demo, we add two types of conditioning to a StyleGAN2 model pre-trained on images of human faces. First, we add text-conditioning to turn StyleGAN2 into a text-to-face model. Second, we add image semantic conditioning to StyleGAN2 to enable face-to-face interpolation. For more details and results on many other generative models, please refer to our paper linked above.")
    # Tab 1: prompt -> generated face.
    with gr.Tab("Text-to-face generation") as text_to_face_generation_demo:
        text_to_face_generation_input = gr.Textbox(label="Enter your prompt", placeholder="e.g. A man with a beard and glasses", max_lines=1)
        text_to_face_generation_button = gr.Button("Generate")
        text_to_face_generation_output = gr.Image(label="Generated image", elem_id="image-gen")
        text_to_face_generation_examples = gr.Examples(examples=examples_text, fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
    # Tab 2: two faces + coefficient -> interpolated face.
    with gr.Tab("Face-to-face interpolation") as face_to_face_interpolation_demo:
        gr.Markdown("We note that interpolations are not expected to recover the given images, even when the coefficient is 0 or 1.")
        with gr.Row():
            face_to_face_interpolation_input1 = gr.Image(label="Image 1", type="pil")
            face_to_face_interpolation_input2 = gr.Image(label="Image 2", type="pil")
        face_to_face_interpolation_lambda = gr.Slider(label="Interpolation coefficient", minimum=0, maximum=1, value=0.5, step=0.01)
        face_to_face_interpolation_button = gr.Button("Interpolate")
        face_to_face_interpolation_output = gr.Image(label="Interpolated image", elem_id="image-gen")
        face_to_face_interpolation_examples = gr.Examples(examples=examples_image, fn=face_to_face_interpolate, inputs=[face_to_face_interpolation_input1, face_to_face_interpolation_input2, face_to_face_interpolation_lambda], outputs=face_to_face_interpolation_output)
    # Event wiring.  Selecting a tab swaps the shared translator weights;
    # pressing Enter in the textbox or clicking a button runs generation.
    text_to_face_generation_demo.select(fn=model_mode_text_select)
    text_to_face_generation_input.submit(fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
    text_to_face_generation_button.click(fn=text_to_face_generate, inputs=text_to_face_generation_input, outputs=text_to_face_generation_output)
    face_to_face_interpolation_demo.select(fn=model_mode_image_select)
    face_to_face_interpolation_button.click(fn=face_to_face_interpolate, inputs=[face_to_face_interpolation_input1, face_to_face_interpolation_input2, face_to_face_interpolation_lambda], outputs=face_to_face_interpolation_output)
# Route events through Gradio's request queue, then start the server.
demo.queue()
demo.launch()
|