import os import re import gradio as gr import torch import torch.nn.functional as F from torch.optim import Adam from torchvision.transforms import transforms as T import clip from tr0n.config import parse_args from tr0n.modules.models.model_stylegan import Model from tr0n.modules.models.loss import AugCosineSimLatent from tr0n.modules.optimizers.sgld import SGLD from bad_words import bad_words device = "cuda" if torch.cuda.is_available() else "cpu" model_modes = { "text": { "checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-text.pth", }, "image": { "checkpoint": "https://huggingface.co/Layer6/tr0n-stylegan2-clip/resolve/main/tr0n-stylegan2-clip-image.pth", } } os.environ['TOKENIZERS_PARALLELISM'] = "false" # set config params config = parse_args(is_demo=True) config_vars = vars(config) config_vars["stylegan_gen"] = "sg2-ffhq-1024" config_vars["with_gmm"] = True config_vars["num_mixtures"] = 10 model = Model(config, device, None) model.to(device) model.eval() for p in model.translator.parameters(): p.requires_grad = False loss = AugCosineSimLatent() transforms_image = T.Compose([ T.Resize(224, interpolation=T.InterpolationMode.BICUBIC), T.CenterCrop(224), T.ToTensor(), T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), ]) checkpoint_text = torch.hub.load_state_dict_from_url(model_modes["text"]["checkpoint"], map_location="cpu") translator_state_dict_text = checkpoint_text['translator_state_dict'] checkpoint_image = torch.hub.load_state_dict_from_url(model_modes["image"]["checkpoint"], map_location="cpu") translator_state_dict_image = checkpoint_image['translator_state_dict'] # default model.translator.load_state_dict(translator_state_dict_text) css = """ a { display: inline-block; color: black !important; text-decoration: none !important; } #image-gen { height: 256px; width: 256px; margin-left: auto; margin-right: auto; } """ def _slerp(val, low, high): low_norm = low / torch.norm(low, dim=1, keepdim=True) high_norm = high / torch.norm(high, dim=1, keepdim=True) omega = torch.acos((low_norm*high_norm).sum(1)) so = torch.sin(omega) res = (torch.sin((1.0-val)*omega)/so).unsqueeze(1)*low + (torch.sin(val*omega)/so).unsqueeze(1) * high return res def model_mode_text_select(): model.translator.load_state_dict(translator_state_dict_text) def model_mode_image_select(): model.translator.load_state_dict(translator_state_dict_image) def text_to_face_generate(text): if text == "": raise gr.Error("You need to provide to provide a prompt.") for word in bad_words: if re.search(rf"\b{word}\b", text): raise gr.Error("Unsafe content found. Please try again with a different prompt.") text_tok = clip.tokenize([text], truncate=True).to(device) # initialize optimization from the translator's output with torch.no_grad(): target_clip_latent, w_mixture_logits, w_means = model(x=text_tok, x_type='text', return_after_translator=True, no_sample=True) pi = w_mixture_logits.unsqueeze(-1).repeat(1, 1, w_means.shape[-1]) # 1 x num_mixtures x w_dim w = w_means # 1 x num_mixtures x w_dim w.requires_grad = True pi.requires_grad = True optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device) optimizer_pi = Adam((pi,), lr=5e-3) # optimization for _ in range(100): soft_pi = F.softmax(pi, dim=1) w_prime = soft_pi * w w_prime = w_prime.sum(dim=1) _, _, pred_clip_latent, _, _ = model(x=w_prime, x_type='gan_latent', times_augment_pred_image=50) l = loss(target_clip_latent, pred_clip_latent) l.backward() torch.nn.utils.clip_grad_norm_((w,), 1.) torch.nn.utils.clip_grad_norm_((pi,), 1.) optimizer_w.step() optimizer_pi.step() optimizer_w.zero_grad() optimizer_pi.zero_grad() # generate final image with torch.no_grad(): soft_pi = F.softmax(pi, dim=1) w_prime = soft_pi * w w_prime = w_prime.sum(dim=1) _, _, _, _, pred_image_raw = model(x=w_prime, x_type='gan_latent') pred_image = ((pred_image_raw[0]+1.)/2.).cpu() return T.ToPILImage()(pred_image) def face_to_face_interpolate(image1, image2, interp_lambda=0.5): if image1 is None or image2 is None: raise gr.Error("You need to provide two images as input.") image1_pt = transforms_image(image1).to(device) image2_pt = transforms_image(image2).to(device) # initialize optimization from the translator's output with torch.no_grad(): images_pt = torch.stack([image1_pt, image2_pt]) target_clip_latents = model.clip.encode_image(images_pt).detach().float() target_clip_latent = _slerp(interp_lambda, target_clip_latents[0].unsqueeze(0), target_clip_latents[1].unsqueeze(0)) _, _, w = model(x=target_clip_latent, x_type='clip_latent', return_after_translator=True) w.requires_grad = True optimizer_w = SGLD((w,), lr=1e-1, momentum=0.99, noise_std=0.01, device=device) # optimization for _ in range(100): _, _, pred_clip_latent, _, _ = model(x=w, x_type='gan_latent', times_augment_pred_image=50) l = loss(target_clip_latent, pred_clip_latent) l.backward() torch.nn.utils.clip_grad_norm_((w,), 1.) optimizer_w.step() optimizer_w.zero_grad() # generate final image with torch.no_grad(): _, _, _, _, pred_image_raw = model(x=w, x_type='gan_latent') pred_image = ((pred_image_raw[0]+1.)/2.).cpu() return T.ToPILImage()(pred_image) examples_text = [ "Muhammad Ali", "Tinker Bell", "A man with glasses, long black hair with sideburns and a goatee", "A child with blue eyes and straight brown hair in the sunshine", "A hairdresser", "A young boy with glasses and an angry face", "Denzel Washington", "A portrait of Angela Merkel", "President Emmanuel Macron", "Prime Minister Shinzo Abe" ] examples_image = [ ["./examples/example_1_1.jpg", "./examples/example_1_2.jpg"], ["./examples/example_2_1.jpg", "./examples/example_2_2.jpg"], ["./examples/example_3_1.jpg", "./examples/example_3_2.jpg"], ["./examples/example_4_1.jpg", "./examples/example_4_2.jpg"], ] with gr.Blocks(css=css) as demo: gr.Markdown("