import random

import torch
import torch.nn.functional as F
from tqdm import tqdm

import lpips

class GaussianHeadTrainer:
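    """Training loop for the Gaussian head avatar.

    Jointly optimizes the Gaussian head model, the super-resolution network,
    and the learnable per-frame pose offsets against L1 and LPIPS losses.
    """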
    def __init__(self, dataloader, delta_poses, gaussianhead, supres, camera, optimizer, recorder, gpu_id):
        self.dataloader = dataloader
        self.delta_poses = delta_poses    # learnable per-frame pose offsets
        self.gaussianhead = gaussianhead  # Gaussian head model
        self.supres = supres              # super-resolution network (may be None)
        self.camera = camera
        self.optimizer = optimizer
        self.recorder = recorder
        self.device = torch.device('cuda:%d' % gpu_id)
        # perceptual loss computed on VGG features
        self.fn_lpips = lpips.LPIPS(net='vgg').to(self.device)
    def train(self, start_epoch=0, epochs=1):
        for epoch in range(start_epoch, epochs):
            for idx, data in tqdm(enumerate(self.dataloader)):
                # prepare data: move every tensor needed this step onto the GPU
                to_cuda = ['images', 'masks', 'visibles', 'images_coarse', 'masks_coarse', 'visibles_coarse',
                           'intrinsics', 'extrinsics', 'world_view_transform', 'projection_matrix', 'full_proj_transform', 'camera_center',
                           'pose', 'scale', 'exp_coeff', 'landmarks_3d', 'exp_id']
                for data_item in to_cuda:
                    data[data_item] = data[data_item].to(device=self.device)
                images = data['images']
                visibles = data['visibles']
                if self.supres is None:
                    # no super-resolution stage: coarse and fine resolutions coincide
                    images_coarse = images
                    visibles_coarse = visibles
                else:
                    images_coarse = data['images_coarse']
                    visibles_coarse = data['visibles_coarse']
                resolution_coarse = images_coarse.shape[2]
                resolution_fine = images.shape[2]
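                # apply the learnable pose offset for this frame (indexed by
                # expression id), letting the optimizer compensate for imperfect
                # head-pose tracking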
                data['pose'] = data['pose'] + self.delta_poses[data['exp_id'], :]
                # render coarse images: generate Gaussian attributes, then rasterize
                data = self.gaussianhead.generate(data)
                data = self.camera.render_gaussian(data, resolution_coarse)
                render_images = data['render_images']
                # crop images for augmentation: sample a random scale in [0.8, 1.25),
                # snapped so the scaled size is an integer number of pixels
                scale_factor = random.random() * 0.45 + 0.8
                scale_factor = int(resolution_coarse * scale_factor) / resolution_coarse
                cropped_render_images, cropped_images, cropped_visibles = self.random_crop(
                    render_images, images, visibles, scale_factor, resolution_coarse, resolution_fine)
                data['cropped_images'] = cropped_images
                # generate super resolution images; fall back to the coarse RGB
                # channels when no super-resolution network is configured
                supres_images = (self.supres(cropped_render_images) if self.supres is not None
                                 else cropped_render_images[:, 0:3, :, :])
                data['supres_images'] = supres_images
                # loss functions: masked L1 on the coarse render and on the
                # super-resolved image against their ground-truth counterparts
                loss_rgb_lr = F.l1_loss(render_images[:, 0:3, :, :] * visibles_coarse, images_coarse * visibles_coarse)
                loss_rgb_hr = F.l1_loss(supres_images * cropped_visibles, cropped_images * cropped_visibles)
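                # evaluate LPIPS on a random square patch (at most 512 px) rather
                # than the full image, which bounds the memory cost of the VGG pass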
                patch = min(512, supres_images.shape[2], supres_images.shape[3])
                left_up = (random.randint(0, supres_images.shape[2] - patch), random.randint(0, supres_images.shape[3] - patch))
                loss_vgg = self.fn_lpips((supres_images * cropped_visibles)[:, :, left_up[0]:left_up[0]+patch, left_up[1]:left_up[1]+patch],
                                         (cropped_images * cropped_visibles)[:, :, left_up[0]:left_up[0]+patch, left_up[1]:left_up[1]+patch],
                                         normalize=True).mean()
                # total loss: fine L1 + coarse L1 + 0.1 * perceptual
                loss = loss_rgb_hr + loss_rgb_lr + loss_vgg * 1e-1
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                log = {
                    'data': data,
                    'delta_poses': self.delta_poses,
                    'gaussianhead': self.gaussianhead,
                    'supres': self.supres,
                    'loss_rgb_lr': loss_rgb_lr,
                    'loss_rgb_hr': loss_rgb_hr,
                    'loss_vgg': loss_vgg,
                    'epoch': epoch,
                    'iter': idx + epoch * len(self.dataloader)
                }
                self.recorder.log(log)
    def random_crop(self, render_images, images, visibles, scale_factor, resolution_coarse, resolution_fine):
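        """Scale all inputs by ``scale_factor`` and return crops at the original
        coarse/fine resolutions, reusing the same (rescaled) random offset for
        both so the coarse render stays aligned with the fine ground truth.

        For scale_factor < 1 the shrunken images are pasted onto an all-ones
        canvas; for scale_factor >= 1 a random window is cropped out.
        """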
        render_images_scaled = F.interpolate(render_images, scale_factor=scale_factor)
        images_scaled = F.interpolate(images, scale_factor=scale_factor)
        visibles_scaled = F.interpolate(visibles, scale_factor=scale_factor)
        if scale_factor < 1:
            # paste the shrunken images at a random offset on an all-ones canvas,
            # scaling the same offset from coarse to fine resolution
            render_images = torch.ones([render_images_scaled.shape[0], render_images_scaled.shape[1], resolution_coarse, resolution_coarse], device=self.device)
            left_up_coarse = (random.randint(0, resolution_coarse - render_images_scaled.shape[2]), random.randint(0, resolution_coarse - render_images_scaled.shape[3]))
            render_images[:, :, left_up_coarse[0]: left_up_coarse[0] + render_images_scaled.shape[2], left_up_coarse[1]: left_up_coarse[1] + render_images_scaled.shape[3]] = render_images_scaled
            images = torch.ones([images_scaled.shape[0], images_scaled.shape[1], resolution_fine, resolution_fine], device=self.device)
            visibles = torch.ones([visibles_scaled.shape[0], visibles_scaled.shape[1], resolution_fine, resolution_fine], device=self.device)
            left_up_fine = (int(left_up_coarse[0] * resolution_fine / resolution_coarse), int(left_up_coarse[1] * resolution_fine / resolution_coarse))
            images[:, :, left_up_fine[0]: left_up_fine[0] + images_scaled.shape[2], left_up_fine[1]: left_up_fine[1] + images_scaled.shape[3]] = images_scaled
            visibles[:, :, left_up_fine[0]: left_up_fine[0] + visibles_scaled.shape[2], left_up_fine[1]: left_up_fine[1] + visibles_scaled.shape[3]] = visibles_scaled
        else:
            # crop a random window of the original resolution out of the enlarged images
            left_up_coarse = (random.randint(0, render_images_scaled.shape[2] - resolution_coarse), random.randint(0, render_images_scaled.shape[3] - resolution_coarse))
            render_images = render_images_scaled[:, :, left_up_coarse[0]: left_up_coarse[0] + resolution_coarse, left_up_coarse[1]: left_up_coarse[1] + resolution_coarse]
            left_up_fine = (int(left_up_coarse[0] * resolution_fine / resolution_coarse), int(left_up_coarse[1] * resolution_fine / resolution_coarse))
            images = images_scaled[:, :, left_up_fine[0]: left_up_fine[0] + resolution_fine, left_up_fine[1]: left_up_fine[1] + resolution_fine]
            visibles = visibles_scaled[:, :, left_up_fine[0]: left_up_fine[0] + resolution_fine, left_up_fine[1]: left_up_fine[1] + resolution_fine]
        return render_images, images, visibles
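

# Example wiring (a hypothetical sketch, not taken from this file): the dataset,
# Gaussian head model, super-resolution network, camera, and recorder are all
# constructed elsewhere in the repository; the pose-offset shape and learning
# rate below are illustrative guesses, not the project's actual values.
#
#   delta_poses = torch.nn.Parameter(torch.zeros(num_frames, 6, device='cuda:0'))
#   optimizer = torch.optim.Adam(
#       list(gaussianhead.parameters()) + list(supres.parameters()) + [delta_poses],
#       lr=1e-4)
#   trainer = GaussianHeadTrainer(dataloader, delta_poses, gaussianhead, supres,
#                                 camera, optimizer, recorder, gpu_id=0)
#   trainer.train(start_epoch=0, epochs=30)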