# GHA/lib/trainer/GaussianHeadTrainer.py
import torch
import torch.nn.functional as F
from tqdm import tqdm
import random
import lpips


class GaussianHeadTrainer():
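    """Trains the Gaussian head avatar: renders coarse images, super-resolves them,
    and optimizes photometric (L1) and perceptual (LPIPS) losses together with per-frame pose offsets."""
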
    def __init__(self, dataloader, delta_poses, gaussianhead, supres, camera, optimizer, recorder, gpu_id):
        self.dataloader = dataloader
        self.delta_poses = delta_poses
        self.gaussianhead = gaussianhead
        self.supres = supres
        self.camera = camera
        self.optimizer = optimizer
        self.recorder = recorder
        self.device = torch.device('cuda:%d' % gpu_id)
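        # LPIPS with a VGG backbone, used for the perceptual loss on super-resolved patches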
        self.fn_lpips = lpips.LPIPS(net='vgg').to(self.device)

    def train(self, start_epoch=0, epochs=1):
        for epoch in range(start_epoch, epochs):
            for idx, data in tqdm(enumerate(self.dataloader)):
                # prepare data
                to_cuda = ['images', 'masks', 'visibles', 'images_coarse', 'masks_coarse', 'visibles_coarse',
                           'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix', 'full_proj_transform', 'camera_center',
                           'pose', 'scale', 'exp_coeff', 'landmarks_3d', 'exp_id']
                for data_item in to_cuda:
                    data[data_item] = data[data_item].to(device=self.device)

                images = data['images']
                visibles = data['visibles']
                if self.supres is None:
                    # no super-resolution module: supervise the coarse render directly at full resolution
                    images_coarse = images
                    visibles_coarse = visibles
                else:
                    images_coarse = data['images_coarse']
                    visibles_coarse = data['visibles_coarse']
                resolution_coarse = images_coarse.shape[2]
                resolution_fine = images.shape[2]
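
                # apply the per-frame pose offset for this sequence index (exp_id)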
                data['pose'] = data['pose'] + self.delta_poses[data['exp_id'], :]

                # render coarse images
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, resolution_coarse)
                render_images = data['render_images']

                # crop images for augmentation
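                # random scale in [0.8, 1.25), snapped so the scaled coarse resolution is an integer pixel count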
                scale_factor = random.random() * 0.45 + 0.8
                scale_factor = int(resolution_coarse * scale_factor) / resolution_coarse
                cropped_render_images, cropped_images, cropped_visibles = self.random_crop(render_images, images, visibles, scale_factor, resolution_coarse, resolution_fine)
                data['cropped_images'] = cropped_images

                # generate super resolution images (fall back to the coarse RGB render when no supres module is configured)
                supres_images = self.supres(cropped_render_images) if self.supres is not None else cropped_render_images[:, 0:3, :, :]
                data['supres_images'] = supres_images

                # loss functions
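                # - loss_rgb_lr: L1 between the coarse render and the coarse ground truth (visible pixels only)
                # - loss_rgb_hr: L1 between the super-resolved output and the cropped fine ground truth
                # - loss_vgg:    LPIPS perceptual loss on a random 512x512 patch of the super-resolved output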
                loss_rgb_lr = F.l1_loss(render_images[:, 0:3, :, :] * visibles_coarse, images_coarse * visibles_coarse)
                loss_rgb_hr = F.l1_loss(supres_images * cropped_visibles, cropped_images * cropped_visibles)
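                # sample the top-left corner of a 512x512 patch (assumes the super-resolved output is at least 512x512)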
                left_up = (random.randint(0, supres_images.shape[2] - 512), random.randint(0, supres_images.shape[3] - 512))
                loss_vgg = self.fn_lpips((supres_images * cropped_visibles)[:, :, left_up[0]:left_up[0]+512, left_up[1]:left_up[1]+512],
                                         (cropped_images * cropped_visibles)[:, :, left_up[0]:left_up[0]+512, left_up[1]:left_up[1]+512], normalize=True).mean()
                loss = loss_rgb_hr + loss_rgb_lr + loss_vgg * 1e-1

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
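
                # hand the current state and losses to the recorder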
                log = {
                    'data': data,
                    'delta_poses': self.delta_poses,
                    'gaussianhead': self.gaussianhead,
                    'supres': self.supres,
                    'loss_rgb_lr': loss_rgb_lr,
                    'loss_rgb_hr': loss_rgb_hr,
                    'loss_vgg': loss_vgg,
                    'epoch': epoch,
                    'iter': idx + epoch * len(self.dataloader)
                }
                self.recorder.log(log)

    def random_crop(self, render_images, images, visibles, scale_factor, resolution_coarse, resolution_fine):
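        """Rescale renders, images, and visibility masks by `scale_factor`, then pad (scale < 1)
        or crop (scale >= 1) back to `resolution_coarse` / `resolution_fine`."""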
        render_images_scaled = F.interpolate(render_images, scale_factor=scale_factor)
        images_scaled = F.interpolate(images, scale_factor=scale_factor)
        visibles_scaled = F.interpolate(visibles, scale_factor=scale_factor)

        if scale_factor < 1:
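            # zoom out: paste the downscaled tensors at a random offset inside a canvas of ones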
            render_images = torch.ones([render_images_scaled.shape[0], render_images_scaled.shape[1], resolution_coarse, resolution_coarse], device=self.device)
            left_up_coarse = (random.randint(0, resolution_coarse - render_images_scaled.shape[2]), random.randint(0, resolution_coarse - render_images_scaled.shape[3]))
            render_images[:, :, left_up_coarse[0]: left_up_coarse[0] + render_images_scaled.shape[2], left_up_coarse[1]: left_up_coarse[1] + render_images_scaled.shape[3]] = render_images_scaled

            images = torch.ones([images_scaled.shape[0], images_scaled.shape[1], resolution_fine, resolution_fine], device=self.device)
            visibles = torch.ones([visibles_scaled.shape[0], visibles_scaled.shape[1], resolution_fine, resolution_fine], device=self.device)
            left_up_fine = (int(left_up_coarse[0] * resolution_fine / resolution_coarse), int(left_up_coarse[1] * resolution_fine / resolution_coarse))
            images[:, :, left_up_fine[0]: left_up_fine[0] + images_scaled.shape[2], left_up_fine[1]: left_up_fine[1] + images_scaled.shape[3]] = images_scaled
            visibles[:, :, left_up_fine[0]: left_up_fine[0] + visibles_scaled.shape[2], left_up_fine[1]: left_up_fine[1] + visibles_scaled.shape[3]] = visibles_scaled
        else:
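            # zoom in: take a random crop of the upscaled tensors at the original resolutions,
            # keeping the fine crop aligned with the coarse crop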
            left_up_coarse = (random.randint(0, render_images_scaled.shape[2] - resolution_coarse), random.randint(0, render_images_scaled.shape[3] - resolution_coarse))
            render_images = render_images_scaled[:, :, left_up_coarse[0]: left_up_coarse[0] + resolution_coarse, left_up_coarse[1]: left_up_coarse[1] + resolution_coarse]
            left_up_fine = (int(left_up_coarse[0] * resolution_fine / resolution_coarse), int(left_up_coarse[1] * resolution_fine / resolution_coarse))
            images = images_scaled[:, :, left_up_fine[0]: left_up_fine[0] + resolution_fine, left_up_fine[1]: left_up_fine[1] + resolution_fine]
            visibles = visibles_scaled[:, :, left_up_fine[0]: left_up_fine[0] + resolution_fine, left_up_fine[1]: left_up_fine[1] + resolution_fine]

        return render_images, images, visibles