# StyleGANEX/utils/inference_utils.py
# Uploaded by PKUWilliamYang ("Upload 8 files", commit ac1883f, 6.21 kB).
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import random
import math
import argparse
import torch
from torch.utils import data
from torch.nn import functional as F
from torch import autograd
from torch.nn import init
import torchvision.transforms as transforms
from scripts.align_all_parallel import get_landmark
def visualize(img_arr, dpi):
    """Show an image tensor in a 10x10 inch matplotlib figure.

    Args:
        img_arr: Tensor indexed as (channel, row, column); it is moved to the
            CPU and its axes are reordered to (row, column, channel).
            Values are mapped with ``(x + 1) * 127.5`` before the uint8 cast,
            so the expected range is presumably [-1, 1].
        dpi: Resolution forwarded to ``plt.figure``.
    """
    plt.figure(figsize=(10, 10), dpi=dpi)
    channels_last = img_arr.detach().cpu().numpy().transpose(1, 2, 0)
    pixels = ((channels_last + 1.0) * 127.5).astype(np.uint8)
    plt.imshow(pixels)
    plt.axis('off')
    plt.show()
def save_image(img, filename):
    """Write an RGB image tensor to ``filename`` using OpenCV.

    Args:
        img: Tensor indexed as (channel, row, column) in RGB order. Values are
            mapped with ``(x + 1) * 127.5`` before the uint8 cast, so the
            expected range is presumably [-1, 1].
        filename: Destination path handed to ``cv2.imwrite``.
    """
    channels_last = img.detach().cpu().numpy().transpose(1, 2, 0)
    rgb = ((channels_last + 1.0) * 127.5).astype(np.uint8)
    # OpenCV stores colour images as BGR, so swap the channel order first.
    bgr = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
    cv2.imwrite(filename, bgr)
def load_image(filename):
    """Load an image file as a normalized tensor with a leading batch axis.

    The image is forced to three-channel RGB before the transform runs, so
    grayscale, palette and RGBA files no longer fail in the three-channel
    ``Normalize`` step.

    Args:
        filename: Path (or file object) accepted by ``PIL.Image.open``.

    Returns:
        The transformed image with an extra dimension inserted at index 0.
        ``ToTensor`` followed by ``Normalize(0.5, 0.5)`` puts values in
        [-1, 1].
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])
    # The context manager releases the file handle once pixels are decoded.
    with Image.open(filename) as handle:
        img = transform(handle.convert('RGB'))
    return img.unsqueeze(dim=0)
def get_video_crop_parameter(filepath, predictor, padding=[256, 256, 256, 256]):
    """Compute the rescale factor and crop window that centre a face.

    The frame is rescaled so that the horizontal distance between the mean
    positions of the two eyes becomes 64 pixels, and a window around the
    mid-point of the eyes is returned in the rescaled coordinates.

    Args:
        filepath: Either a path string (the image is read from disk as RGB)
            or an already-loaded image array indexed as (row, column, ...).
        predictor: Landmark predictor forwarded unchanged to
            ``get_landmark``.
        padding: Extents kept around the eye centre, in rescaled pixels, in
            the order (left, right, top, bottom). Only read, never mutated.

    Returns:
        ``None`` when ``get_landmark`` returns ``None``; otherwise the tuple
        ``(h, w, top, bottom, left, right, scale)`` where ``h``/``w`` are the
        rescaled frame size and the four window bounds are clipped to the
        frame and floored to multiples of 8.
    """
    if isinstance(filepath, str):
        # The module never imports dlib, so the former
        # ``dlib.load_rgb_image`` call raised NameError. Decode with the
        # already-imported PIL instead; both yield an RGB uint8 array.
        with Image.open(filepath) as handle:
            img = np.array(handle.convert('RGB'))
    else:
        img = filepath

    lm = get_landmark(img, predictor)
    if lm is None:
        return None

    # Only the eye landmarks are needed here.
    # NOTE(review): indices assume the 68-point layout — confirm against
    # get_landmark.
    lm_eye_left = lm[36:42]
    lm_eye_right = lm[42:48]

    # Scale so that the inter-eye distance along x becomes 64 px.
    scale = 64. / (np.mean(lm_eye_right[:, 0]) - np.mean(lm_eye_left[:, 0]))
    center = ((np.mean(lm_eye_right, axis=0) + np.mean(lm_eye_left, axis=0)) / 2) * scale
    h, w = round(img.shape[0] * scale), round(img.shape[1] * scale)

    # Clip to the rescaled frame, then floor each bound to a multiple of 8.
    left = max(round(center[0] - padding[0]), 0) // 8 * 8
    right = min(round(center[0] + padding[1]), w) // 8 * 8
    top = max(round(center[1] - padding[2]), 0) // 8 * 8
    bottom = min(round(center[1] + padding[3]), h) // 8 * 8
    return h, w, top, bottom, left, right, scale
def tensor2cv2(img):
    """Convert an RGB image tensor into a BGR uint8 array for OpenCV.

    Args:
        img: Tensor indexed as (channel, row, column) in RGB order. Values are
            mapped with ``(x + 1) * 127.5`` before the uint8 cast, so the
            expected range is presumably [-1, 1].

    Returns:
        A uint8 numpy array indexed as (row, column, channel) in BGR order.
    """
    # detach() matches save_image/visualize and lets tensors that still
    # require grad be converted instead of raising in .numpy().
    tmp = ((img.detach().cpu().numpy().transpose(1, 2, 0) + 1.0) * 127.5).astype(np.uint8)
    return cv2.cvtColor(tmp, cv2.COLOR_RGB2BGR)
def noise_regularize(noises):
    """Accumulate a shift-correlation penalty over a list of noise maps.

    For every map, the squared mean of the product between the map and a
    copy rolled by one position along dim 3, plus the same along dim 2, is
    added to the total. The map is then halved with bilinear interpolation
    and the penalty repeated until the tracked size (initially
    ``shape[2]``) is 8 or less.

    Args:
        noises: Iterable of 4-D tensors.

    Returns:
        The accumulated penalty; the integer ``0`` when ``noises`` is empty.
    """
    total = 0
    for level in noises:
        extent = level.shape[2]
        current = level
        while True:
            along_width = (current * torch.roll(current, shifts=1, dims=3)).mean().pow(2)
            along_height = (current * torch.roll(current, shifts=1, dims=2)).mean().pow(2)
            total = total + along_width + along_height
            if extent <= 8:
                break
            # Bilinear halving is used here in place of an earlier 2x2
            # reshape-and-mean pooling.
            current = F.interpolate(current, scale_factor=0.5, mode='bilinear')
            extent //= 2
    return total
def noise_normalize_(noises):
    """Standardize each noise tensor in place (trailing ``_`` = in-place).

    Every tensor has its own mean subtracted and is divided by its own
    standard deviation, writing through ``.data``.

    Args:
        noises: Iterable of tensors to modify.

    Returns:
        None.
    """
    for tensor in noises:
        mu, sigma = tensor.mean(), tensor.std()
        # Statistics are captured before the in-place update begins.
        tensor.data.add_(-mu).div_(sigma)
def get_lr(t, initial_lr, rampdown=0.25, rampup=0.05):
    """Return the learning rate at progress ``t`` with warm-up and cosine decay.

    Args:
        t: Progress value; the formulas use ``t`` and ``1 - t``, so it is
            presumably in [0, 1].
        initial_lr: Peak learning rate.
        rampdown: Length of the final cosine decay, in units of ``t``.
        rampup: Length of the initial linear warm-up, in units of ``t``.

    Returns:
        ``initial_lr`` scaled by the product of the cosine and warm-up
        factors.
    """
    decay_phase = min(1, (1 - t) / rampdown)
    cosine_factor = 0.5 - 0.5 * math.cos(decay_phase * math.pi)
    warmup_factor = min(1, t / rampup)
    return initial_lr * (cosine_factor * warmup_factor)
def latent_noise(latent, strength):
    """Return ``latent`` plus standard-normal noise scaled by ``strength``.

    Args:
        latent: Tensor to perturb; it is not modified.
        strength: Multiplier applied to the sampled noise.

    Returns:
        A new tensor ``latent + randn_like(latent) * strength``.
    """
    jitter = torch.randn_like(latent)
    jitter = jitter * strength
    return latent + jitter
def make_image(tensor):
    """Convert a batch of image tensors into a uint8 numpy array.

    Values are clamped to [-1, 1], mapped linearly to [0, 255] and cast to
    uint8; the axes are reordered from (0, 1, 2, 3) to (0, 2, 3, 1).

    The clamp is no longer done in place: ``detach()`` shares storage with
    its source, so the previous ``clamp_`` silently overwrote the caller's
    tensor.

    Args:
        tensor: 4-D tensor; it is left unmodified.

    Returns:
        A uint8 numpy array on the CPU with the second axis moved last.
    """
    return (
        tensor.detach()
        .clamp(min=-1, max=1)  # fresh tensor, so the in-place ops below are safe
        .add_(1)
        .div_(2)
        .mul_(255)
        .type(torch.uint8)
        .permute(0, 2, 3, 1)
        .to("cpu")
        .numpy()
    )
# from pix2pixHD
# Converts a one-hot tensor into a colorful label map
def tensor2label(label_tensor, n_label, imtype=np.uint8):
    """Render a label tensor as a colour image array.

    Args:
        label_tensor: Tensor whose first dimension holds either a single
            label map or per-class scores.
        n_label: Number of classes used to build the colour map; ``0``
            delegates to ``tensor2im``.
        imtype: Numpy dtype of the returned array.

    Returns:
        A numpy array with the colour axis moved last, cast to ``imtype``.
    """
    if n_label == 0:
        # NOTE(review): tensor2im is not defined or imported in the visible
        # part of this file — confirm it exists before relying on this path.
        return tensor2im(label_tensor, imtype)

    labels = label_tensor.cpu().float()
    if labels.size()[0] > 1:
        # Collapse the class axis to the index of its largest entry.
        _, labels = labels.max(0, keepdim=True)

    colored = Colorize(n_label)(labels)
    return np.transpose(colored.numpy(), (1, 2, 0)).astype(imtype)
def uint82bin(n, count=8):
    """Return the lowest ``count`` bits of ``n`` as a string, MSB first."""
    bits = []
    for shift in reversed(range(count)):
        bits.append(str((n >> shift) & 1))
    return ''.join(bits)
def labelcolormap(N):
    """Build an ``(N, 3)`` uint8 table mapping label indices to RGB colours.

    ``N == 35`` returns a fixed hand-written palette (marked "cityscape").
    Any other ``N`` generates colours by spreading the bits of the label
    index over the three channels: successive groups of three bits feed the
    red, green and blue channels, starting from each channel's highest bit.

    The generated branch uses plain integer bit operations. It yields the
    same table as the previous string / ``np.uint8`` round trip and no
    longer shadows the builtin ``id``.

    Args:
        N: Number of labels.

    Returns:
        A numpy array of shape ``(N, 3)`` and dtype ``uint8``.
    """
    if N == 35:  # cityscape
        return np.array(
            [
                (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 0, 0), (111, 74, 0), (81, 0, 81),
                (128, 64, 128), (244, 35, 232), (250, 170, 160), (230, 150, 140), (70, 70, 70), (102, 102, 156), (190, 153, 153),
                (180, 165, 180), (150, 100, 100), (150, 120, 90), (153, 153, 153), (153, 153, 153), (250, 170, 30), (220, 220, 0),
                (107, 142, 35), (152, 251, 152), (70, 130, 180), (220, 20, 60), (255, 0, 0), (0, 0, 142), (0, 0, 70),
                (0, 60, 100), (0, 0, 90), (0, 0, 110), (0, 80, 100), (0, 0, 230), (119, 11, 32), (0, 0, 142),
            ],
            dtype=np.uint8,
        )

    cmap = np.zeros((N, 3), dtype=np.uint8)
    for i in range(N):
        r = g = b = 0
        remaining = i
        for j in range(7):
            shift = 7 - j
            # Bits 0, 1, 2 of the current group go to r, g, b respectively.
            r ^= (remaining & 1) << shift
            g ^= ((remaining >> 1) & 1) << shift
            b ^= ((remaining >> 2) & 1) << shift
            remaining >>= 3
        cmap[i] = (r, g, b)
    return cmap
class Colorize(object):
    """Callable that paints a label map with a fixed colour table."""

    def __init__(self, n=35):
        """Build the colour table for ``n`` labels from ``labelcolormap``."""
        table = labelcolormap(n)
        self.cmap = torch.from_numpy(table[:n])

    def __call__(self, gray_image):
        """Return a ``(3, H, W)`` uint8 tensor coloured by label.

        Args:
            gray_image: Tensor whose first slice ``gray_image[0]`` holds the
                label of every pixel.

        Returns:
            A CPU uint8 tensor; pixels whose label has no table entry stay
            zero.
        """
        _, height, width = gray_image.size()
        color_image = torch.zeros(3, height, width, dtype=torch.uint8)
        for label, color in enumerate(self.cmap):
            selected = (gray_image[0] == label).cpu()
            for channel in range(3):
                color_image[channel][selected] = color[channel]
        return color_image