Spaces:
Sleeping
Sleeping
import torch | |
import numpy as np | |
from einops import rearrange | |
class Calibrator():
    """Fit a face model to multi-view 2D landmarks by gradient descent.

    The parameters of ``bfm`` are optimised in two Adam stages that share the
    same loss: a coarse stage (lr 1e-2) followed by a fine stage (lr 1e-3).
    The loss is the confidence-weighted squared reprojection error of the
    model's landmarks in every view plus ``bfm.reg_loss(5e-6, 1e-6)``. Each
    stage stops when the loss plateaus or after ``max_iters`` steps.
    """

    def __init__(self, dataset, bfm, camera, recorder, *, device='cuda:0', max_iters=100000000):
        """Store collaborators and build one Adam optimiser per stage.

        Args:
            dataset: provides ``get_item()`` returning
                ``(landmarks_gt, extrinsics, intrinsics, eids)``; the first
                three must be numpy arrays (they go through ``torch.from_numpy``).
            bfm: ``torch.nn.Module`` whose call returns a pair with the 3D
                landmarks second, and which provides ``reg_loss(w1, w2)``.
                It is not moved here, so it must already be on ``device``
                (the loss combines its output with the dataset tensors).
            camera: provides ``image_size`` and ``perspective(points, calibrations)``.
            recorder: provides ``log(dict)``; receives the fitted result.
            device: device the dataset tensors are moved to. Defaults to the
                previously hard-coded ``'cuda:0'``.
            max_iters: upper bound on optimisation steps per stage, used when
                the plateau test never fires.
        """
        self.dataset = dataset
        self.bfm = bfm
        self.camera = camera
        self.recorder = recorder
        self.device = torch.device(device)
        self.max_iters = max_iters
        self.optimizer1 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-2}])
        self.optimizer2 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-3}])

    def calibrate(self):
        """Fit ``bfm`` to one dataset item and pass the result to ``recorder.log``.

        The logged dict has keys ``eids``, ``landmarks_gt``, ``landmarks_2d``
        (projection of the *final* parameters, without grad), ``bfm``,
        ``intrinsics`` and ``extrinsics`` (both with the view axis intact).
        """
        landmarks_gt, extrinsics0, intrinsics0, eids = self._load_item()
        batch_size, num_views = landmarks_gt.shape[0], landmarks_gt.shape[1]
        # Fold the view axis into the batch axis so every (sample, view)
        # camera goes through a single bmm: (B, V, r, c) -> (B*V, r, c).
        extrinsics = extrinsics0.flatten(0, 1)
        intrinsics = intrinsics0.flatten(0, 1)

        # Coarse stage with loose plateau tolerances, then fine stage with tight ones.
        self._fit(self.optimizer1, landmarks_gt, intrinsics, extrinsics, 1e-8, 1e-7)
        self._fit(self.optimizer2, landmarks_gt, intrinsics, extrinsics, 1e-11, 1e-10)

        # Re-project with the final parameters: the last in-loop projection was
        # computed before the final optimizer step (and does not exist at all
        # when no step ran).
        with torch.no_grad():
            _, landmarks_3d = self.bfm()
            landmarks_2d = self._project_views(landmarks_3d, intrinsics, extrinsics, batch_size, num_views)

        self.recorder.log({
            'eids': eids,
            'landmarks_gt': landmarks_gt,
            'landmarks_2d': landmarks_2d,
            'bfm': self.bfm,
            'intrinsics': intrinsics0,
            'extrinsics': extrinsics0,
        })

    def _load_item(self):
        """Fetch one dataset item; convert its arrays to float32 tensors on ``self.device``."""
        landmarks_gt, extrinsics, intrinsics, eids = self.dataset.get_item()
        landmarks_gt, extrinsics, intrinsics = (
            torch.from_numpy(array).float().to(self.device)
            for array in (landmarks_gt, extrinsics, intrinsics)
        )
        return landmarks_gt, extrinsics, intrinsics, eids

    def _fit(self, optimizer, landmarks_gt, intrinsics, extrinsics, tol, tol2):
        """Run one optimisation stage with ``optimizer``.

        Stops once the loss differs from the previous step's loss by less
        than ``tol`` and from the loss two steps back by less than ``tol2``,
        or after ``self.max_iters`` steps. The loss history is reset per stage
        so the fine stage's plateau test is not fed coarse-stage values.
        With float32 data, the fine-stage tolerances are below the resolution
        of typical loss values, so that stage may only end at ``max_iters``.
        """
        batch_size, num_views = landmarks_gt.shape[0], landmarks_gt.shape[1]
        prev_loss = pprev_loss = 1e8
        for i in range(self.max_iters):
            optimizer.zero_grad()
            _, landmarks_3d = self.bfm()
            landmarks_2d = self._project_views(landmarks_3d, intrinsics, extrinsics, batch_size, num_views)
            pro_loss = self._projection_loss(landmarks_2d, landmarks_gt)
            reg_loss = self.bfm.reg_loss(5e-6, 1e-6)
            loss = pro_loss + reg_loss
            loss.backward()
            optimizer.step()

            if i % 100 == 0:
                print(pro_loss.item(), reg_loss.item())
            loss_value = loss.item()
            if abs(loss_value - prev_loss) < tol and abs(loss_value - pprev_loss) < tol2:
                break
            pprev_loss, prev_loss = prev_loss, loss_value

    def _project_views(self, landmarks_3d, intrinsics, extrinsics, batch_size, num_views):
        """Project per-sample 3D landmarks into every view; returns a (B, V, ...) tensor."""
        # landmarks_3d is rank 3 (presumably (B, N, 3)); copy it once per view
        # and fold views into the batch: (B, N, C) -> (B*V, N, C).
        per_view = landmarks_3d.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
        landmarks_2d = self.project(per_view, intrinsics, extrinsics)
        # Undo the folding: (B*V, ...) -> (B, V, ...).
        return landmarks_2d.reshape(batch_size, -1, *landmarks_2d.shape[1:])

    def _projection_loss(self, landmarks_2d, landmarks_gt):
        """Confidence-weighted squared reprojection error in image-size-normalised units.

        Channels 0:2 of ``landmarks_gt`` are the target 2D positions; channel 2
        scales each landmark's residual (presumably a detection confidence).
        Summed over coordinates and landmarks, averaged over samples and views.
        """
        size = self.camera.image_size
        residual = (landmarks_2d / size - landmarks_gt[:, :, :, 0:2] / size) * landmarks_gt[:, :, :, 2:3]
        return (residual ** 2).sum(-1).sum(-2).mean()

    def project(self, points_3d, intrinsic, extrinsic):
        """Project batched points with per-item ``intrinsic @ extrinsic`` matrices.

        ``points_3d`` is (batch, points, coords); it is transposed to
        coords-first for ``camera.perspective`` and the result is transposed
        back to points-first.
        """
        points_3d = points_3d.permute(0, 2, 1)
        calibrations = torch.bmm(intrinsic, extrinsic)
        points_2d = self.camera.perspective(points_3d, calibrations)
        return points_2d.permute(0, 2, 1)
class CalibratorSingleView():
    """Fit a face model to 2D landmarks with fixed-length optimisation stages.

    Like a plateau-terminated calibrator, but each stage runs a fixed number
    of Adam steps: a coarse stage (lr 1e-2) whose loss is reprojection error
    plus ``bfm.reg_loss(5e-6, 1e-6)``, then a fine stage (lr 1e-3) that also
    adds ``bfm.temporal_smooth_loss(3e-5)``.
    """

    def __init__(self, dataset, bfm, camera, recorder, *, device='cuda:0', coarse_iters=3000, fine_iters=1000):
        """Store collaborators and build one Adam optimiser per stage.

        Args:
            dataset: provides ``get_item()`` returning
                ``(landmarks_gt, extrinsics, intrinsics, eids)``; the first
                three must be numpy arrays (they go through ``torch.from_numpy``).
            bfm: ``torch.nn.Module`` whose call returns a pair with the 3D
                landmarks second, and which provides ``reg_loss(w1, w2)`` and
                ``temporal_smooth_loss(w)``. It is not moved here, so it must
                already be on ``device``.
            camera: provides ``image_size`` and ``perspective(points, calibrations)``.
            recorder: provides ``log(dict)``; receives the fitted result.
            device: device the dataset tensors are moved to. Defaults to the
                previously hard-coded ``'cuda:0'``.
            coarse_iters: number of steps in the coarse stage (was fixed at 3000).
            fine_iters: number of steps in the fine stage (was fixed at 1000).
        """
        self.dataset = dataset
        self.bfm = bfm
        self.camera = camera
        self.recorder = recorder
        self.device = torch.device(device)
        self.coarse_iters = coarse_iters
        self.fine_iters = fine_iters
        self.optimizer1 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-2}])
        self.optimizer2 = torch.optim.Adam([{'params': self.bfm.parameters(), 'lr': 1e-3}])

    def calibrate(self):
        """Fit ``bfm`` to one dataset item and pass the result to ``recorder.log``.

        The logged dict has keys ``eids``, ``landmarks_gt``, ``landmarks_2d``
        (projection of the *final* parameters, without grad), ``bfm``,
        ``intrinsics`` and ``extrinsics`` (both with the view axis intact).
        """
        landmarks_gt, extrinsics0, intrinsics0, eids = self._load_item()
        batch_size, num_views = landmarks_gt.shape[0], landmarks_gt.shape[1]
        # Fold the view axis into the batch axis: (B, V, r, c) -> (B*V, r, c).
        extrinsics = extrinsics0.flatten(0, 1)
        intrinsics = intrinsics0.flatten(0, 1)

        self._fit(self.optimizer1, self.coarse_iters, landmarks_gt, intrinsics, extrinsics)
        self._fit(self.optimizer2, self.fine_iters, landmarks_gt, intrinsics, extrinsics,
                  temporal_weight=3e-5)

        # Re-project with the final parameters: the last in-loop projection was
        # computed before the final optimizer step (and does not exist at all
        # when no step ran).
        with torch.no_grad():
            _, landmarks_3d = self.bfm()
            landmarks_2d = self._project_views(landmarks_3d, intrinsics, extrinsics, batch_size, num_views)

        self.recorder.log({
            'eids': eids,
            'landmarks_gt': landmarks_gt,
            'landmarks_2d': landmarks_2d,
            'bfm': self.bfm,
            'intrinsics': intrinsics0,
            'extrinsics': extrinsics0,
        })

    def _load_item(self):
        """Fetch one dataset item; convert its arrays to float32 tensors on ``self.device``."""
        landmarks_gt, extrinsics, intrinsics, eids = self.dataset.get_item()
        landmarks_gt, extrinsics, intrinsics = (
            torch.from_numpy(array).float().to(self.device)
            for array in (landmarks_gt, extrinsics, intrinsics)
        )
        return landmarks_gt, extrinsics, intrinsics, eids

    def _fit(self, optimizer, num_iters, landmarks_gt, intrinsics, extrinsics, temporal_weight=None):
        """Run ``num_iters`` steps of ``optimizer``; print the step index and loss every 100 steps.

        When ``temporal_weight`` is given, ``bfm.temporal_smooth_loss`` with
        that weight is added to the regulariser.
        """
        batch_size, num_views = landmarks_gt.shape[0], landmarks_gt.shape[1]
        for i in range(num_iters):
            optimizer.zero_grad()
            _, landmarks_3d = self.bfm()
            landmarks_2d = self._project_views(landmarks_3d, intrinsics, extrinsics, batch_size, num_views)
            pro_loss = self._projection_loss(landmarks_2d, landmarks_gt)
            reg_loss = self.bfm.reg_loss(5e-6, 1e-6)
            if temporal_weight is not None:
                reg_loss = reg_loss + self.bfm.temporal_smooth_loss(temporal_weight)
            loss = pro_loss + reg_loss
            loss.backward()
            optimizer.step()
            if i % 100 == 0:
                print(i, loss.item())

    def _project_views(self, landmarks_3d, intrinsics, extrinsics, batch_size, num_views):
        """Project per-sample 3D landmarks into every view; returns a (B, V, ...) tensor."""
        # landmarks_3d is rank 3 (presumably (B, N, 3)); copy it once per view
        # and fold views into the batch: (B, N, C) -> (B*V, N, C).
        per_view = landmarks_3d.unsqueeze(1).repeat(1, num_views, 1, 1).flatten(0, 1)
        landmarks_2d = self.project(per_view, intrinsics, extrinsics)
        # Undo the folding: (B*V, ...) -> (B, V, ...).
        return landmarks_2d.reshape(batch_size, -1, *landmarks_2d.shape[1:])

    def _projection_loss(self, landmarks_2d, landmarks_gt):
        """Confidence-weighted squared reprojection error in image-size-normalised units.

        Channels 0:2 of ``landmarks_gt`` are the target 2D positions; channel 2
        scales each landmark's residual (presumably a detection confidence).
        Summed over coordinates and landmarks, averaged over samples and views.
        """
        size = self.camera.image_size
        residual = (landmarks_2d / size - landmarks_gt[:, :, :, 0:2] / size) * landmarks_gt[:, :, :, 2:3]
        return (residual ** 2).sum(-1).sum(-2).mean()

    def project(self, points_3d, intrinsic, extrinsic):
        """Project batched points with per-item ``intrinsic @ extrinsic`` matrices.

        ``points_3d`` is (batch, points, coords); it is transposed to
        coords-first for ``camera.perspective`` and the result is transposed
        back to points-first.
        """
        points_3d = points_3d.permute(0, 2, 1)
        calibrations = torch.bmm(intrinsic, extrinsic)
        points_2d = self.camera.perspective(points_3d, calibrations)
        return points_2d.permute(0, 2, 1)