# pengc02's picture
# all
# ec9a6bc
# raw
# history blame
# 7.91 kB
import torch
import numpy as np
from einops import rearrange
class Calibrator():
    """Fit a BFM model to multi-view 2D landmarks with two Adam stages.

    Stage one optimises ``bfm.parameters()`` with learning rate 1e-2, stage
    two refines the same parameters with learning rate 1e-3.  A stage ends
    when the loss has stopped changing (see ``calibrate``) or after
    ``max_iters`` steps.  The result is handed to ``recorder.log``.
    """

    def __init__(self, dataset, bfm, camera, recorder):
        """Store the collaborators and build one Adam optimizer per stage.

        Args:
            dataset: object whose ``get_item()`` returns
                ``(landmarks_gt, extrinsics, intrinsics, eids)``; the first
                three are converted with ``torch.from_numpy``.
            bfm: callable model exposing ``parameters()`` and
                ``reg_loss(a, b)``; calling it returns a pair whose second
                element is the 3D landmark tensor.
            camera: object exposing ``image_size`` and
                ``perspective(points, calibrations)``.
            recorder: object exposing ``log(dict)``.
        """
        self.dataset = dataset
        self.bfm = bfm
        self.camera = camera
        self.recorder = recorder
        # Use the first GPU when there is one; otherwise stay on the CPU
        # instead of failing later when the inputs are moved to the device.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.optimizer1 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-2}])
        self.optimizer2 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-3}])

    def calibrate(self, max_iters=100000000):
        """Run both optimisation stages and log the outcome.

        Args:
            max_iters: upper bound on the number of optimizer steps taken in
                each stage.  The default matches the previous hard-coded
                bound.

        Raises:
            ValueError: if ``max_iters`` is smaller than 1 (no projection
                would exist to log).
        """
        if max_iters < 1:
            raise ValueError('max_iters must be at least 1')

        landmarks_gt, extrinsics0, intrinsics0, eids = self.dataset.get_item()
        landmarks_gt = torch.from_numpy(landmarks_gt).float().to(self.device)
        extrinsics0 = torch.from_numpy(extrinsics0).float().to(self.device)
        intrinsics0 = torch.from_numpy(intrinsics0).float().to(self.device)

        # Merge the leading (batch, view) axes so every view is one bmm item.
        extrinsics = extrinsics0.flatten(0, 1)
        intrinsics = intrinsics0.flatten(0, 1)

        # Loop-invariant pieces of the projection loss.
        target_xy = landmarks_gt[:, :, :, 0:2] / self.camera.image_size
        # Third channel scales each landmark's residual
        # (presumably a confidence/visibility weight -- TODO confirm).
        weights = landmarks_gt[:, :, :, 2:3]

        # The second stage deliberately continues with the loss history left
        # behind by the first stage, as the original implementation did.
        _, prev_loss, pprev_loss = self._run_stage(
            self.optimizer1, intrinsics, extrinsics, target_xy, weights,
            1e8, 1e8, 1e-8, 1e-7, max_iters)
        landmarks_2d, _, _ = self._run_stage(
            self.optimizer2, intrinsics, extrinsics, target_xy, weights,
            prev_loss, pprev_loss, 1e-11, 1e-10, max_iters)

        log = {
            'eids': eids,
            'landmarks_gt': landmarks_gt,
            'landmarks_2d': landmarks_2d.detach(),
            'bfm': self.bfm,
            'intrinsics': intrinsics0,
            'extrinsics': extrinsics0
        }
        self.recorder.log(log)

    def _project_landmarks(self, intrinsics, extrinsics, batch, num_views):
        """Evaluate the BFM and project its landmarks into every view.

        Returns a tensor whose leading axis is split back into
        ``(batch, views)``.
        """
        _, landmarks_3d = self.bfm()
        # Give every view its own copy of the model landmarks, then merge the
        # (batch, view) axes to match the flattened camera matrices.
        landmarks_3d = landmarks_3d.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
        landmarks_2d = self.project(landmarks_3d, intrinsics, extrinsics)
        return landmarks_2d.reshape(batch, -1, *landmarks_2d.shape[1:])

    def _run_stage(self, optimizer, intrinsics, extrinsics, target_xy, weights,
                   prev_loss, pprev_loss, tol_prev, tol_pprev, max_iters):
        """Step ``optimizer`` until the loss stalls or ``max_iters`` is hit.

        The stage stops when the current loss is within ``tol_prev`` of the
        previous loss and within ``tol_pprev`` of the one before that.

        Returns:
            ``(landmarks_2d, prev_loss, pprev_loss)`` where ``landmarks_2d``
            is the projection computed in the last iteration (before that
            iteration's optimizer step).
        """
        batch, num_views = weights.shape[0], weights.shape[1]
        landmarks_2d = None
        for i in range(max_iters):
            optimizer.zero_grad()
            landmarks_2d = self._project_landmarks(intrinsics, extrinsics, batch, num_views)

            residual = (landmarks_2d / self.camera.image_size - target_xy) * weights
            # Sum over the coordinate axis, then over the second-to-last of
            # the remaining axes, and average what is left.
            pro_loss = (residual ** 2).sum(-1).sum(-2).mean()
            reg_loss = self.bfm.reg_loss(5e-6, 1e-6)
            loss = pro_loss + reg_loss
            loss.backward()
            optimizer.step()

            # Read the scalar once; every .item() forces a device sync.
            loss_value = loss.item()
            if i % 100 == 0:
                print(pro_loss.item(), reg_loss.item())
            if abs(loss_value - prev_loss) < tol_prev and abs(loss_value - pprev_loss) < tol_pprev:
                break
            pprev_loss = prev_loss
            prev_loss = loss_value
        return landmarks_2d, prev_loss, pprev_loss

    def project(self, points_3d, intrinsic, extrinsic):
        """Project batched 3D points with per-item camera matrices.

        ``points_3d`` has its last two axes swapped before being passed to
        ``camera.perspective`` together with ``intrinsic @ extrinsic``; the
        result is swapped back, so the output is laid out like the input.
        """
        points_3d = points_3d.permute(0, 2, 1)
        calibrations = torch.bmm(intrinsic, extrinsic)
        points_2d = self.camera.perspective(points_3d, calibrations)
        points_2d = points_2d.permute(0, 2, 1)
        return points_2d
class CalibratorSingleView():
    """Fit a BFM model to 2D landmarks with two fixed-length Adam stages.

    Stage one optimises ``bfm.parameters()`` with learning rate 1e-2; stage
    two refines with learning rate 1e-3 and adds
    ``bfm.temporal_smooth_loss(3e-5)`` to the regulariser.  Unlike an
    early-stopping scheme, every stage always runs its full iteration count.
    """

    def __init__(self, dataset, bfm, camera, recorder):
        """Store the collaborators and build one Adam optimizer per stage.

        Args:
            dataset: object whose ``get_item()`` returns
                ``(landmarks_gt, extrinsics, intrinsics, eids)``; the first
                three are converted with ``torch.from_numpy``.
            bfm: callable model exposing ``parameters()``,
                ``reg_loss(a, b)`` and ``temporal_smooth_loss(w)``; calling
                it returns a pair whose second element is the 3D landmarks.
            camera: object exposing ``image_size`` and
                ``perspective(points, calibrations)``.
            recorder: object exposing ``log(dict)``.
        """
        self.dataset = dataset
        self.bfm = bfm
        self.camera = camera
        self.recorder = recorder
        # Use the first GPU when there is one; otherwise stay on the CPU
        # instead of failing later when the inputs are moved to the device.
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
        self.optimizer1 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-2}])
        self.optimizer2 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-3}])

    def calibrate(self, stage1_iters=3000, stage2_iters=1000):
        """Run both optimisation stages and log the outcome.

        Args:
            stage1_iters: optimizer steps taken with ``optimizer1``.
            stage2_iters: optimizer steps taken with ``optimizer2`` (this
                stage also applies the temporal smoothness term).

        Raises:
            ValueError: if neither stage would run at least one step (no
                projection would exist to log).
        """
        if stage1_iters < 1 and stage2_iters < 1:
            raise ValueError('at least one stage must run for one iteration')

        landmarks_gt, extrinsics0, intrinsics0, eids = self.dataset.get_item()
        landmarks_gt = torch.from_numpy(landmarks_gt).float().to(self.device)
        extrinsics0 = torch.from_numpy(extrinsics0).float().to(self.device)
        intrinsics0 = torch.from_numpy(intrinsics0).float().to(self.device)

        # Merge the leading (batch, view) axes so every view is one bmm item.
        extrinsics = extrinsics0.flatten(0, 1)
        intrinsics = intrinsics0.flatten(0, 1)

        # Loop-invariant pieces of the projection loss.
        target_xy = landmarks_gt[:, :, :, 0:2] / self.camera.image_size
        # Third channel scales each landmark's residual
        # (presumably a confidence/visibility weight -- TODO confirm).
        weights = landmarks_gt[:, :, :, 2:3]

        coarse_2d = self._run_stage(
            self.optimizer1, stage1_iters, intrinsics, extrinsics, target_xy, weights)
        refined_2d = self._run_stage(
            self.optimizer2, stage2_iters, intrinsics, extrinsics, target_xy, weights,
            smooth_weight=3e-5)
        # Log the most recent projection; stage two wins whenever it ran.
        landmarks_2d = refined_2d if refined_2d is not None else coarse_2d

        log = {
            'eids': eids,
            'landmarks_gt': landmarks_gt,
            'landmarks_2d': landmarks_2d.detach(),
            'bfm': self.bfm,
            'intrinsics': intrinsics0,
            'extrinsics': extrinsics0
        }
        self.recorder.log(log)

    def _project_landmarks(self, intrinsics, extrinsics, batch, num_views):
        """Evaluate the BFM and project its landmarks into every view.

        Returns a tensor whose leading axis is split back into
        ``(batch, views)``.
        """
        _, landmarks_3d = self.bfm()
        # Give every view its own copy of the model landmarks, then merge the
        # (batch, view) axes to match the flattened camera matrices.
        landmarks_3d = landmarks_3d.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
        landmarks_2d = self.project(landmarks_3d, intrinsics, extrinsics)
        return landmarks_2d.reshape(batch, -1, *landmarks_2d.shape[1:])

    def _run_stage(self, optimizer, num_iters, intrinsics, extrinsics,
                   target_xy, weights, smooth_weight=None):
        """Take ``num_iters`` steps with ``optimizer``.

        When ``smooth_weight`` is given, ``bfm.temporal_smooth_loss`` with
        that weight is added to the regulariser.

        Returns:
            The projection computed in the last iteration (before that
            iteration's optimizer step), or ``None`` if no step was taken.
        """
        batch, num_views = weights.shape[0], weights.shape[1]
        landmarks_2d = None
        for i in range(num_iters):
            optimizer.zero_grad()
            landmarks_2d = self._project_landmarks(intrinsics, extrinsics, batch, num_views)

            residual = (landmarks_2d / self.camera.image_size - target_xy) * weights
            # Sum over the coordinate axis, then over the second-to-last of
            # the remaining axes, and average what is left.
            pro_loss = (residual ** 2).sum(-1).sum(-2).mean()
            reg_loss = self.bfm.reg_loss(5e-6, 1e-6)
            if smooth_weight is not None:
                reg_loss = reg_loss + self.bfm.temporal_smooth_loss(smooth_weight)
            loss = pro_loss + reg_loss
            loss.backward()
            optimizer.step()

            # Only pull the scalar off the device when it is actually shown.
            if i % 100 == 0:
                print(i, loss.item())
        return landmarks_2d

    def project(self, points_3d, intrinsic, extrinsic):
        """Project batched 3D points with per-item camera matrices.

        ``points_3d`` has its last two axes swapped before being passed to
        ``camera.perspective`` together with ``intrinsic @ extrinsic``; the
        result is swapped back, so the output is laid out like the input.
        """
        points_3d = points_3d.permute(0, 2, 1)
        calibrations = torch.bmm(intrinsic, extrinsic)
        points_2d = self.camera.perspective(points_3d, calibrations)
        points_2d = points_2d.permute(0, 2, 1)
        return points_2d