import torch
import numpy as np
from einops import rearrange


class _CalibratorBase:
    """Shared machinery for fitting a BFM model to 2D landmarks seen from calibrated views.

    Subclasses implement ``calibrate`` (the optimisation schedule); this base
    owns data loading, the forward pass / loss, a single optimiser step,
    projection and the final logging.
    """

    def __init__(self, dataset, bfm, camera, recorder, device='cuda:0'):
        """
        Args:
            dataset: object whose ``get_item()`` returns numpy arrays
                ``(landmarks_gt, extrinsics, intrinsics)`` plus ``eids``.
            bfm: model exposing ``parameters()``, ``reg_loss(a, b)`` and a
                call returning a pair whose second item is the 3D landmarks
                (presumably a ``torch.nn.Module`` — confirm).
            camera: object with ``image_size`` and
                ``perspective(points, calibrations)``.
            recorder: sink whose ``log(dict)`` receives the final results.
            device: torch device for all tensors; defaults to ``'cuda:0'``
                (the previously hard-coded value).
        """
        self.dataset = dataset
        self.bfm = bfm
        self.camera = camera
        self.recorder = recorder
        self.device = torch.device(device)
        # Coarse stage (lr 1e-2) followed by a fine stage (lr 1e-3). Both
        # optimisers update the same parameters but keep separate Adam state.
        self.optimizer1 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-2}])
        self.optimizer2 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-3}])

    def _load_batch(self):
        """Fetch one dataset item and move its arrays to ``self.device`` as float32.

        Returns:
            ``(landmarks_gt, extrinsics0, intrinsics0, eids)``; the camera
            matrices keep their 4-D ``(b, v, x, y)`` layout.
        """
        landmarks_gt, extrinsics0, intrinsics0, eids = self.dataset.get_item()

        def to_tensor(array):
            return torch.from_numpy(array).float().to(self.device)

        return to_tensor(landmarks_gt), to_tensor(extrinsics0), to_tensor(intrinsics0), eids

    @staticmethod
    def _flatten_views(matrices):
        """Merge the batch and view axes: ``(b, v, x, y) -> (b*v, x, y)``."""
        return rearrange(matrices, 'b v x y -> (b v) x y')

    def _predict_landmarks_2d(self, landmarks_gt, intrinsics, extrinsics):
        """Evaluate the BFM and project its 3D landmarks into every view.

        ``intrinsics``/``extrinsics`` are the view-flattened ``(b*v, x, y)``
        matrices. Returns a 4-D ``(b, v, n, c)`` tensor, where ``c`` is whatever
        ``camera.perspective`` emits (compared against 2 ground-truth coords).
        """
        _, landmarks_3d = self.bfm()
        batch, views = landmarks_gt.shape[0], landmarks_gt.shape[1]
        # The same 3D landmarks are seen by every view: tile them per view and
        # flatten (batch, view) so a single batched projection suffices.
        landmarks_3d = landmarks_3d.unsqueeze(1).repeat(1, views, 1, 1)
        landmarks_3d = rearrange(landmarks_3d, 'b v x y -> (b v) x y')
        landmarks_2d = self.project(landmarks_3d, intrinsics, extrinsics)
        return rearrange(landmarks_2d, '(b v) x y -> b v x y', b=batch)

    def _losses(self, landmarks_gt, intrinsics, extrinsics, temporal_weight=None):
        """Return ``(pro_loss, reg_loss)`` for the current BFM parameters.

        ``pro_loss`` is the weighted squared reprojection error in coordinates
        normalised by ``camera.image_size``; ``reg_loss`` is the BFM prior,
        plus ``bfm.temporal_smooth_loss(temporal_weight)`` when a weight is given.
        """
        landmarks_2d = self._predict_landmarks_2d(landmarks_gt, intrinsics, extrinsics)
        image_size = self.camera.image_size
        # Channel 2 of the ground truth multiplies the residual, i.e. acts as a
        # per-landmark weight (presumably detection confidence — confirm).
        weight = landmarks_gt[:, :, :, 2:3]
        residual = (landmarks_2d / image_size - landmarks_gt[:, :, :, 0:2] / image_size) * weight
        # On the (b, v, n, c) residual: sum over coords, then over dim -2 of the
        # remaining (b, v, n) tensor — i.e. over VIEWS — then mean over (b, n).
        # NOTE(review): summing over views rather than landmarks may be
        # unintended, but reg_loss weights are tuned against this scale.
        pro_loss = (residual ** 2).sum(-1).sum(-2).mean()
        reg_loss = self.bfm.reg_loss(5e-6, 1e-6)
        if temporal_weight is not None:
            reg_loss = reg_loss + self.bfm.temporal_smooth_loss(temporal_weight)
        return pro_loss, reg_loss

    def _step(self, optimizer, landmarks_gt, intrinsics, extrinsics, temporal_weight=None):
        """Run one forward/backward/update with ``optimizer``.

        Returns:
            ``(pro_loss, reg_loss, loss_value)``: the two loss terms as detached
            tensors and the total loss as a Python float (one device sync).
        """
        optimizer.zero_grad()
        pro_loss, reg_loss = self._losses(landmarks_gt, intrinsics, extrinsics, temporal_weight)
        loss = pro_loss + reg_loss
        loss.backward()
        optimizer.step()
        return pro_loss.detach(), reg_loss.detach(), loss.item()

    def _record(self, eids, landmarks_gt, intrinsics0, extrinsics0, intrinsics, extrinsics):
        """Send the fitted state to ``self.recorder``.

        The 2D landmarks are re-projected with the final parameters: the ones
        computed inside the last loop iteration predate that iteration's
        optimizer step and would not match the ``bfm`` being logged.
        """
        with torch.no_grad():
            landmarks_2d = self._predict_landmarks_2d(landmarks_gt, intrinsics, extrinsics)
        log = {
            'eids': eids,
            'landmarks_gt': landmarks_gt,
            'landmarks_2d': landmarks_2d.detach(),
            'bfm': self.bfm,
            'intrinsics': intrinsics0,
            'extrinsics': extrinsics0,
        }
        self.recorder.log(log)

    def project(self, points_3d, intrinsic, extrinsic):
        """Project batched 3D points with per-item camera matrices.

        Args:
            points_3d: ``(B, N, C)`` tensor of points (C presumably 3 — confirm).
            intrinsic: ``(B, x, y)`` matrices, left-multiplied onto ``extrinsic``.
            extrinsic: ``(B, x, y)`` matrices.

        Returns:
            ``camera.perspective`` output with its last two axes swapped back,
            i.e. points along axis 1.
        """
        points_3d = points_3d.permute(0, 2, 1)
        calibrations = torch.bmm(intrinsic, extrinsic)
        points_2d = self.camera.perspective(points_3d, calibrations)
        return points_2d.permute(0, 2, 1)


class Calibrator(_CalibratorBase):
    """Two-stage BFM fit where each stage runs until the total loss stops changing."""

    # Safety cap per stage; the loops normally exit on convergence.
    MAX_ITERS = 100000000
    # (|loss - prev_loss|, |loss - pprev_loss|) thresholds per stage.
    # NOTE(review): with a float32 loss above roughly 1e-4, a difference below
    # 1e-11 likely requires bit-identical losses, so stage 2 may hit MAX_ITERS.
    STAGE1_TOL = (1e-8, 1e-7)
    STAGE2_TOL = (1e-11, 1e-10)

    def calibrate(self):
        """Fit the BFM to one dataset item with a coarse then a fine Adam stage, then log."""
        landmarks_gt, extrinsics0, intrinsics0, eids = self._load_batch()
        extrinsics = self._flatten_views(extrinsics0)
        intrinsics = self._flatten_views(intrinsics0)
        targets = (landmarks_gt, intrinsics, extrinsics)

        # The loss history carries over from stage 1 into stage 2.
        history = (1e8, 1e8)
        history = self._run_until_converged(self.optimizer1, targets, self.STAGE1_TOL, history)
        self._run_until_converged(self.optimizer2, targets, self.STAGE2_TOL, history)

        self._record(eids, landmarks_gt, intrinsics0, extrinsics0, intrinsics, extrinsics)

    def _run_until_converged(self, optimizer, targets, tol, history):
        """Step ``optimizer`` until the loss changes less than ``tol`` vs. the last two steps.

        Args:
            optimizer: the optimiser for this stage.
            targets: ``(landmarks_gt, intrinsics, extrinsics)`` for ``_step``.
            tol: ``(tol_prev, tol_pprev)`` convergence thresholds.
            history: ``(prev_loss, pprev_loss)`` at the start of the stage.

        Returns:
            The ``(prev_loss, pprev_loss)`` history at exit, excluding the
            converged step (matching the original bookkeeping).
        """
        prev_loss, pprev_loss = history
        tol_prev, tol_pprev = tol
        for i in range(self.MAX_ITERS):
            pro_loss, reg_loss, loss = self._step(optimizer, *targets)
            if abs(loss - prev_loss) < tol_prev and abs(loss - pprev_loss) < tol_pprev:
                print(pro_loss.item(), reg_loss.item())
                break
            pprev_loss, prev_loss = prev_loss, loss
            if i % 100 == 0:
                print(pro_loss.item(), reg_loss.item())
        return prev_loss, pprev_loss


class CalibratorSingleView(_CalibratorBase):
    """Two-stage BFM fit with fixed iteration budgets; stage 2 adds temporal smoothing."""

    STAGE1_ITERS = 3000
    STAGE2_ITERS = 1000
    # Argument passed to ``bfm.temporal_smooth_loss`` in stage 2
    # (presumably its weight — confirm against the BFM implementation).
    TEMPORAL_WEIGHT = 3e-5

    def calibrate(self):
        """Fit the BFM to one dataset item with fixed-length coarse and fine stages, then log."""
        landmarks_gt, extrinsics0, intrinsics0, eids = self._load_batch()
        extrinsics = self._flatten_views(extrinsics0)
        intrinsics = self._flatten_views(intrinsics0)
        targets = (landmarks_gt, intrinsics, extrinsics)

        self._run_fixed(self.optimizer1, targets, self.STAGE1_ITERS)
        self._run_fixed(self.optimizer2, targets, self.STAGE2_ITERS,
                        temporal_weight=self.TEMPORAL_WEIGHT)

        self._record(eids, landmarks_gt, intrinsics0, extrinsics0, intrinsics, extrinsics)

    def _run_fixed(self, optimizer, targets, iters, temporal_weight=None):
        """Step ``optimizer`` exactly ``iters`` times, printing the loss every 100 steps."""
        for i in range(iters):
            _, _, loss = self._step(optimizer, *targets, temporal_weight=temporal_weight)
            if i % 100 == 0:
                print(i, loss)