# model/base_model.py
import os
import torch
from abc import ABC, abstractmethod
import wandb
import options
import utils
from pytorch_msssim import ssim, SSIM
import numpy as np
import torchvision
from tqdm import tqdm
import lpips
from imaginaire.losses import FeatureMatchingLoss, GaussianKLLoss, PerceptualLoss, GANLoss
import cv2
from imaginaire.utils.trainer import get_scheduler
from .geometry_transform import render_sat
from model import geometry_transform
import csv
class BaseModel(ABC):
"""This class is an abstract base class (ABC) for models.
To create a subclass, you need to implement the following five functions:
-- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
-- <set_input>: unpack data from dataset and apply preprocessing.
-- <forward>: produce intermediate results.
-- <optimize_parameters>: calculate losses, gradients, and update network weights.
-- <modify_commandline_options>: (optionally) add model-specific options and set default options.
"""
    def __init__(self, opt, wandb=None):
        """Initialize the BaseModel class.
        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        When creating your custom class, you need to implement your own initialization.
        In this function, you should first call <BaseModel.__init__(self, opt)>.
        Then, you need to define the following lists:
            -- self.loss_names (str list): specify the training losses that you want to plot and save.
            -- self.model_names (str list): define the networks used in training.
        """
        self.wandb = wandb
        if opt.isTrain:
            opt.save_dir = wandb.dir
            options.save_options_file(opt, opt.save_dir)
        self.opt = opt
        self.device = "cpu" if opt.cpu or not torch.cuda.is_available() else "cuda:{}".format(opt.gpu)
# torch.backends.cudnn.benchmark = True
self.model_names = []
self.train_loader = None
self.val_loader = None
self.sty_loader = None
        self.loss_fn_alex = lpips.LPIPS(net='alex', eval_mode=True).to(self.device)
        if opt.task == 'test':
            self.loss_fn_sque = lpips.LPIPS(net='squeeze', eval_mode=True).to(self.device)
        self.mseloss = torch.nn.MSELoss()
self.criteria = {}
self.weights = {}
        # Build the loss dictionary from the weights configured in opt.optim.loss_weight;
        # a loss is enabled only when its weight is present and non-zero.
        if getattr(opt.optim.loss_weight, 'GaussianKL', None):
            self.criteria['GaussianKL'] = GaussianKLLoss()
            self.weights['GaussianKL'] = opt.optim.loss_weight.GaussianKL
        if getattr(opt.optim.loss_weight, 'L1', None):
            self.criteria['L1'] = torch.nn.L1Loss()
            self.weights['L1'] = opt.optim.loss_weight.L1
        if getattr(opt.optim.loss_weight, 'L2', None):
            self.criteria['L2'] = torch.nn.MSELoss()
            self.weights['L2'] = opt.optim.loss_weight.L2
        if getattr(opt.optim.loss_weight, 'SSIM', None):
            self.criteria['SSIM'] = SSIM(data_range=1., size_average=True, channel=3)
            self.weights['SSIM'] = opt.optim.loss_weight.SSIM
        if getattr(opt.optim.loss_weight, 'Perceptual', None):
            self.criteria['Perceptual'] = PerceptualLoss(
                network=opt.optim.perceptual_loss.mode,
                layers=opt.optim.perceptual_loss.layers,
                weights=opt.optim.perceptual_loss.weights).to(self.device)
            self.weights['Perceptual'] = opt.optim.loss_weight.Perceptual
        if getattr(opt.optim.loss_weight, 'sky_inner', None):
            self.criteria['sky_inner'] = torch.nn.L1Loss()
            self.weights['sky_inner'] = opt.optim.loss_weight.sky_inner
        if getattr(opt.optim.loss_weight, 'feature_matching', None):
            self.criteria['feature_matching'] = FeatureMatchingLoss()
            self.weights['feature_matching'] = opt.optim.loss_weight.feature_matching
        # the GAN loss is always enabled
        self.weights['GAN'] = opt.optim.loss_weight.GAN
        self.criteria['GAN'] = GANLoss(gan_mode=opt.optim.gan_mode)
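        # For reference, a hypothetical `opt.optim` fragment that would enable
        # the losses above (key names follow the checks in this block; the
        # values are illustrative only, not the project's shipped config):
        #   optim:
        #     gan_mode: hinge
        #     loss_weight: {GAN: 1.0, L1: 10.0, SSIM: 1.0, sky_inner: 1.0}
        #     perceptual_loss: {mode: vgg19, layers: ['relu_3_1'], weights: [1.0]}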
@staticmethod
def modify_commandline_options(parser, is_train):
"""Add new model-specific options, and rewrite default values for existing options.
Parameters:
parser -- original option parser
is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
Returns:
the modified parser.
"""
return parser
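    # Example override in a subclass (illustrative; the flag below is hypothetical):
    #   @staticmethod
    #   def modify_commandline_options(parser, is_train):
    #       parser.add_argument('--my_flag', type=int, default=1)
    #       return parser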
@abstractmethod
def set_input(self, input):
"""Unpack input data from the dataloader and perform necessary pre-processing steps.
Parameters:
input (dict): includes the data itself and its metadata information.
"""
pass
@abstractmethod
def forward(self):
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
pass
    def save_checkpoint(self, ep=0, latest=False):
        """Save trained models.
        Args:
            ep (int, optional): model epoch. Defaults to 0.
            latest (bool, optional): whether it is the latest model. Defaults to False.
        """
        ckpt_save_path = os.path.join(self.wandb.dir, 'checkpoint')
        os.makedirs(ckpt_save_path, exist_ok=True)
        utils.save_checkpoint(self, ep=ep, latest=latest, output_path=ckpt_save_path)
        if not latest:
            print("checkpoint saved: {0}, epoch {1}".format(self.opt.name, ep))
def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
"""Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
key = keys[i]
if i + 1 == len(keys): # at the end, pointing to a parameter/buffer
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'running_mean' or key == 'running_var'):
if getattr(module, key) is None:
state_dict.pop('.'.join(keys))
if module.__class__.__name__.startswith('InstanceNorm') and \
(key == 'num_batches_tracked'):
state_dict.pop('.'.join(keys))
else:
self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
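    # Typical usage of the helper above (pattern inherited from the
    # pytorch-CycleGAN-and-pix2pix codebase; shown only as an illustration):
    #   for key in list(state_dict.keys()):
    #       self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))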
    def setup_optimizer(self, opt):
        # Initialize optimizers; the schedulers are created below when a lr policy is configured.
self.optimizer_G = torch.optim.Adam(self.netG.parameters(), lr=opt.optim.lr_gen, betas=(opt.optim.beta1, 0.999),eps=1.e-7)
if opt.isTrain:
self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=opt.optim.lr_dis, betas=(opt.optim.beta1, 0.999))
if opt.optim.lr_policy:
self.sch_G = get_scheduler(opt.optim, self.optimizer_G)
self.sch_D = get_scheduler(opt.optim, self.optimizer_D)
    def optimize_parameters(self, opt):
        self.netG.train()
        # update discriminator: calculate gradients for D
        self.backward_D(opt)
        # update generator: calculate gradients for G
        self.backward_G(opt)
        # PSNR in dB for images in [0, 1]: PSNR = -10 * log10(MSE)
        psnr1 = -10 * self.mseloss(self.fake_B.detach(), self.real_B.detach()).log10().item()
        ssim_ = ssim(self.real_B.detach().float(), self.fake_B.detach().float(), data_range=1.)
        out_dict = {
            "train_ssim": ssim_,
            "train_psnr1": psnr1,
        }
        # log the current learning rates when a lr schedule is active
        if opt.optim.lr_policy:
            out_dict["lr_D"] = self.sch_D.get_lr()[0]
            out_dict["lr_G"] = self.sch_G.get_lr()[0]
        out_dict.update(self.loss)
        out_dict.update(self.dis_losses)
        self.wandb.log(out_dict)
    def validation(self, opt):
        """Run validation/test in the center ground-view synthesis setting.
        Args:
            opt: option namespace/dict
        """
print(10*"*","validate",10*"*")
self.netG.eval()
# six image reconstruction metrics
psnr_val = []
ssim_val = []
lpips_ale_val = []
lpips_squ_val = []
rmse_val = []
sd_val = []
with torch.no_grad():
            # use the sky histogram of the predefined style image (opt.sty_img) for all images
if opt.sty_img:
for _,data in enumerate(self.sty_loader):
self.set_input(data)
self.style_temp=self.sky_histc
break
for _,data in enumerate(tqdm(self.val_loader,ncols=100)):
self.set_input(data)
# if true: use the sky of predefined image
# if false: use the sky of corresponding GT
if opt.sty_img:
self.sky_histc = self.style_temp
self.forward(opt)
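                # metric conventions: RMSE is computed on [0, 255] intensities,
                # PSNR/SSIM on [0, 1], and LPIPS expects inputs rescaled to
                # [-1, 1], hence the 2*x - 1 below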
                rmse = torch.sqrt(self.mseloss(self.fake_B * 255., self.real_B * 255.)).item()
                sd = sd_func(self.real_B, self.fake_B)
                rmse_val.append(rmse)
                sd_val.append(sd)
                psnr1 = -10 * self.mseloss(self.fake_B, self.real_B).log10().item()
                ssim_ = ssim(self.real_B, self.fake_B, data_range=1.).item()
                lpips_ale = torch.mean(self.loss_fn_alex((self.real_B * 2.) - 1, (2. * self.fake_B) - 1)).cpu()
                if opt.task == 'test':
                    lpips_sque = torch.mean(self.loss_fn_sque((self.real_B * 2.) - 1, (2. * self.fake_B) - 1)).cpu()
                    lpips_squ_val.append(lpips_sque)
                psnr_val.append(psnr1)
                ssim_val.append(ssim_)
                lpips_ale_val.append(lpips_ale)
if opt.task in ['vis_test']:
if not os.path.exists(opt.vis_dir):
os.mkdir(opt.vis_dir)
sat_opacity,sat_depth = render_sat(opt,self.out_put['voxel'])
self.out_put['depth'] = (self.out_put['depth']/self.out_put['depth'].max())*255.
sat_depth = (sat_depth/sat_depth.max())*255.
for i in range(len(self.fake_B)):
depth_save = cv2.applyColorMap(self.out_put['depth'][i].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO)
depth_sat_save = cv2.applyColorMap(sat_depth[i].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO)
# cat generated ground images, GT ground images, predicted ground depth
torchvision.utils.save_image([self.fake_B[i].cpu(),self.real_B[i].cpu(),torch.flip(torch.from_numpy(depth_save).permute(2,0,1)/255.,[0])],os.path.join(opt.vis_dir,os.path.basename(self.image_paths[i])))
# cat GT satellite images, predicted satellite depth
torchvision.utils.save_image( [self.real_A[i].cpu() ,torch.flip(torch.from_numpy(depth_sat_save).permute(2,0,1)/255.,[0])],os.path.join(opt.vis_dir,os.path.basename(self.image_paths[i]).rsplit('.', 1)[0]+'_sat.jpg'))
                        # ground opacity (saved separately so it does not overwrite the *_sat.jpg above)
                        torchvision.utils.save_image([self.out_put['opacity'][i]], os.path.join(opt.vis_dir, os.path.basename(self.image_paths[i]).rsplit('.', 1)[0] + '_opacity.jpg'))
        psnr_avg = np.average(psnr_val)
        ssim_avg = np.average(ssim_val)
        lpips_ale_avg = np.average(lpips_ale_val)
        # lpips_squ is only collected when task == 'test' (see above); guard
        # against averaging an empty list
        lpips_squ_avg = np.average(lpips_squ_val) if lpips_squ_val else float('nan')
        rmse_avg = np.average(rmse_val)
        sd_avg = np.average(sd_val)
        if opt.task in ['train', 'Train']:
            out_dict = {
                'val_psnr': psnr_avg,
                'val_ssim': ssim_avg,
                'val_lpips_ale': lpips_ale_avg,
                'val_rmse': rmse_avg,
                'val_sd': sd_avg,
            }
            self.wandb.log(out_dict, commit=False)
else:
print(
{
'val_rmse':rmse_avg,
'val_ssim': ssim_avg,
'val_psnr': psnr_avg,
'val_sd':sd_avg,
'val_lpips_ale':lpips_ale_avg,
'val_lpips_squ':lpips_squ_avg,
}
)
with open('test_output.csv', mode='a', newline='') as csv_file:
writer = csv.writer(csv_file)
writer.writerow([rmse_avg, ssim_avg, psnr_avg, sd_avg, lpips_ale_avg, lpips_squ_avg])
    def test_vid(self, opt):
        """Synthesize a ground-view video along a trajectory of rendering positions.
        Args:
            opt: option namespace/dict
        """
ckpt_list = os.listdir('wandb/')
for i in ckpt_list:
if opt.test_ckpt_path in i:
ckpt_path = i
ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG']
print('load success!')
self.netG.load_state_dict(ckpt,strict=True)
self.netG.eval()
print(10*"*","test_video",10*"*")
pixels = []
if os.path.exists('vis_video/pixels.csv'):
with open('vis_video/pixels.csv', 'r') as csvfile:
reader = csv.DictReader(csvfile)
            for row in reader:
                x = float(row['x'])  # pixel x-coordinate in the satellite image
                y = float(row['y'])  # pixel y-coordinate in the satellite image
                pixels.append((x, y))
        else:
            print('vis_video/pixels.csv not found; rendering only the center point')
            pixels = [(128, 128)]
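        # Expected vis_video/pixels.csv layout (a header row is required by
        # csv.DictReader; the values below are only an illustration):
        #   x,y
        #   128.0,128.0
        #   130.5,126.0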
if opt.sty_img:
# inference with illumination from other images
for idx,data in enumerate(self.sty_loader):
self.set_input(data)
self.style_temp=self.sky_histc
break
with torch.no_grad():
for idx,data in enumerate(self.val_loader):
self.set_input(data)
if opt.sty_img:
self.sky_histc = self.style_temp
for i,(x,y) in enumerate(pixels):
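                    # map pixel (x, y) in the 256 x 256 satellite image to
                    # normalized rendering-origin coordinates in [-1, 1]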
                    opt.origin_H_W = [(y - 128) / 128, (x - 128) / 128]
                    print(opt.origin_H_W)
self.forward(opt)
if not os.path.exists('vis_video'):
os.mkdir('vis_video')
                    # save the voxel grid (for visualization) and the satellite depth; works well on CVACT
                    if i == 0:
                        # pre-process the volume for better visualization
                        volume_data = self.out_put.voxel.squeeze().cpu().numpy().transpose((1, 2, 0))
                        volume_data = np.clip(volume_data, None, 10)
                        import pyvista as pv  # local import: pyvista is only needed for this export
                        grid = pv.UniformGrid()
                        grid.dimensions = volume_data.shape
                        grid.spacing = (1, 1, 1)
                        grid.origin = (0, 0, 0)
                        grid.point_data['values'] = volume_data.flatten(order='F')
                        grid.save(os.path.join('vis_video', "volume_data.vtk"))  # .vtk files can be opened in ParaView
sat_opacity,sat_depth = render_sat(opt,self.out_put['voxel'])
sat_depth = (2 - sat_depth)/(opt.data.max_height/15)*255.
depth_sat_save = cv2.applyColorMap(sat_depth[0].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO)
torchvision.utils.save_image(torch.flip(torch.from_numpy(depth_sat_save).permute(2,0,1)/255.,[0]) ,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png','_satdepth.png'))
torchvision.utils.save_image( [self.real_A[0].cpu() ] ,os.path.join('vis_video',os.path.basename(self.image_paths[0]).replace('.png','_sat.png')))
torchvision.utils.save_image( [self.real_B[0].cpu() ] ,os.path.join('vis_video',os.path.basename(self.image_paths[0]).replace('.png','_pano.png')))
self.out_put['depth'] = (self.out_put['depth']/self.out_put['depth'].max())*255.
depth_save = cv2.applyColorMap(self.out_put['depth'][0].squeeze().cpu().numpy().astype(np.uint8), cv2.COLORMAP_TURBO)
depth_save = torch.flip(torch.from_numpy(depth_save).permute(2,0,1)/255.,[0])
save_img = self.out_put.pred[0].cpu()
name = '%05d' % int(i) + ".png"
torchvision.utils.save_image(save_img,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png',name))
save_img = depth_save
name = '%05d' % int(i) + "_depth.png"
torchvision.utils.save_image(save_img,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png',name))
# save_img = self.out_put.generator_inputs[0][:3,:,:]
# name = '%05d' % int(i) + "_color_project.png"
# torchvision.utils.save_image(save_img,os.path.join('vis_video',os.path.basename(self.image_paths[0])).replace('.png',name))
    def test_interpolation(self, opt):
        """Interpolate between two sky styles and save the intermediate renderings.
        Args:
            opt: option namespace/dict
        """
ckpt_list = os.listdir('wandb/')
for i in ckpt_list:
if opt.test_ckpt_path in i:
ckpt_path = i
ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG']
print('load success!')
self.netG.load_state_dict(ckpt,strict=True)
self.netG.eval()
pixels = [(128,128)]
if opt.sty_img1:
for idx,data in enumerate(self.sty_loader1):
self.set_input(data)
self.style_temp1=self.sky_histc
break
if opt.sty_img2:
for idx,data in enumerate(self.sty_loader2):
self.set_input(data)
self.style_temp2=self.sky_histc
break
with torch.no_grad():
for idx,data in enumerate(self.val_loader):
self.set_input(data)
self.sky_histc1 = self.style_temp1
self.sky_histc2 = self.style_temp2
x,y = pixels[0]
opt.origin_H_W = [(y-128)/128 , (x-128)/128]
print(opt.origin_H_W)
estimated_height = self.netG.depth_model(self.real_A)
geo_outputs = geometry_transform.render(opt,self.real_A,estimated_height,self.netG.pano_direction,PE=self.netG.PE)
generator_inputs,opacity,depth = geo_outputs['rgb'],geo_outputs['opacity'],geo_outputs['depth']
if self.netG.gen_cfg.cat_opa:
generator_inputs = torch.cat((generator_inputs,opacity),dim=1)
if self.netG.gen_cfg.cat_depth:
generator_inputs = torch.cat((generator_inputs,depth),dim=1)
_, _, z1 = self.netG.style_encode(self.sky_histc1)
_, _, z2 = self.netG.style_encode(self.sky_histc2)
num_inter = 60
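                # linearly blend the two style codes:
                # z(t) = (1 - t) * z1 + t * z2 with t = i / (num_inter - 1)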
for i in range(num_inter):
                    z = z1 * (1 - i / (num_inter - 1)) + z2 * (i / (num_inter - 1))
z = self.netG.style_model(z)
output_RGB = self.netG.denoise_model(generator_inputs,z)
save_img = output_RGB.cpu()
name = 'img{:03d}.png'.format(i)
if not os.path.exists('vis_interpolation'):
os.mkdir('vis_interpolation')
torchvision.utils.save_image(save_img,os.path.join('vis_interpolation',name))
    def test_speed(self, opt):
        self.netG.eval()
        random_input = torch.randn(1, 3, 256, 256).to(opt.device)
starter, ender = torch.cuda.Event(enable_timing=True), torch.cuda.Event(enable_timing=True)
iterations = 300
times = torch.zeros(iterations)
with torch.no_grad():
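            # warm-up runs: let cuDNN select kernels and the GPU reach a
            # steady state before timing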
for _ in range(50):
_ = self.netG(random_input,None,opt)
            for it in range(iterations):
                starter.record()
                _ = self.netG(random_input, None, opt)
                ender.record()
                torch.cuda.synchronize()  # wait for GPU work to finish before reading the timer
                curr_time = starter.elapsed_time(ender)  # elapsed time in milliseconds
                times[it] = curr_time
        mean_time = times.mean().item()
        print("Inference time: {:.6f} ms, FPS: {}".format(mean_time, 1000 / mean_time))
    def test_sty(self, opt):
ckpt_list = os.listdir('wandb/')
for i in ckpt_list:
if opt.test_ckpt_path in i:
ckpt_path = i
ckpt = torch.load(os.path.join('wandb/',ckpt_path,'files/checkpoint/model.pth'))['netG']
print('load success!')
self.netG.load_state_dict(ckpt,strict=True)
        self.netG.eval()
        print(10 * "*", "test_sty", 10 * "*")
self.style_temp_list = []
with torch.no_grad():
num_val_loader = len(self.val_loader)
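            # Render every validation image under every collected sky style:
            # the i == 0 pass stores one sky histogram per batch, and each
            # subsequent pass i re-renders all images with the i-th stored
            # style (an N x N image/style grid).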
            for i in range(num_val_loader):
                for idx, data in enumerate(tqdm(self.val_loader, ncols=100)):
                    self.set_input(data)
                    if not os.path.exists(opt.vis_dir):
                        os.mkdir(opt.vis_dir)
                    if i == 0:
                        self.style_temp_list.append(self.sky_histc)
                        name = '%05d' % int(idx)
                        torchvision.utils.save_image([self.real_A[0].cpu()], os.path.join(opt.vis_dir, os.path.basename(self.image_paths[0]).replace('.png', name + '_sat.png')))
                    self.sky_histc = self.style_temp_list[i]
                    self.forward(opt)
                    name = '%05d' % int(idx) + '_' + '%05d' % int(i) + '.png'
                    torchvision.utils.save_image(self.fake_B[0].cpu(), os.path.join(opt.vis_dir, name))
    def train(self, opt):
        # run an initial validation pass before training starts
        self.validation(opt)
for current_epoch in range(opt.max_epochs):
print(10*'-','current epoch is ',current_epoch,10*'-')
for idx,data in enumerate(tqdm(self.train_loader,ncols=100)):
self.set_input(data)
self.optimize_parameters(opt)
                if idx % 500 == 0:
out_ing_dict = {
'train_input': wandb.Image(self.real_A[0].float()),
'train_pred_and_gt': wandb.Image(torch.cat([self.fake_B,self.real_B],2)[0].float()),
}
if hasattr(self.out_put, 'inter_RGB'):
out_ing_dict["train_inner_pred"] = wandb.Image(self.out_put.inter_RGB[0].float())
if opt.arch.gen.transform_mode in ['volum_rendering']:
out_ing_dict['train_inner_opacity'] = wandb.Image(self.out_put.opacity[0].float())
self.wandb.log(out_ing_dict,commit=False)
                # step the schedulers per iteration when iteration_mode is set...
                if opt.optim.lr_policy.iteration_mode:
                    self.sch_G.step()
                    self.sch_D.step()
            # ...otherwise step them once per epoch
            if not opt.optim.lr_policy.iteration_mode:
                self.sch_G.step()
                self.sch_D.step()
self.validation(opt)
if current_epoch%5==0:
self.save_checkpoint(ep=current_epoch)
self.save_checkpoint(ep=current_epoch)
    def test(self, opt):
        # find the (non-zip) wandb run directory matching the checkpoint id
        ckpt_list = os.listdir('wandb/')
        for i in ckpt_list:
            if '.zip' not in i and opt.test_ckpt_path in i:
                ckpt_path = i
        ckpt = torch.load(os.path.join('wandb/', ckpt_path, 'files/checkpoint/model.pth'))['netG']
        print('load success!')
        self.netG.load_state_dict(ckpt, strict=True)
        self.validation(opt)
        print('with --task=vis_test, visual results are saved to', opt.vis_dir, '(pass "--vis_dir=xxx" to use a different directory)')
    def _get_outputs(self, net_D_output, real=True):
        r"""Return discriminator output values. Note that for relativistic GAN
        modes the real/fake difference would be taken before returning (see
        `_get_difference` below).
        Args:
            net_D_output (dict):
                real_outputs (tensor): Real output values.
                fake_outputs (tensor): Fake output values.
            real (bool): Return real or fake outputs.
        """
def _get_difference(a, b):
r"""Get difference between two lists of tensors or two tensors.
Args:
a: list of tensors or tensor
b: list of tensors or tensor
"""
out = list()
for x, y in zip(a, b):
if isinstance(x, list):
res = _get_difference(x, y)
else:
res = x - y
out.append(res)
return out
if real:
return net_D_output['real_outputs']
else:
return net_D_output['fake_outputs']
def sd_func(real, fake):
    '''Sharpness difference (SD) metric between real and fake images.
    Measures how well image gradients are preserved; see page 6 of
    https://arxiv.org/abs/1511.05440 (Mathieu et al., "Deep multi-scale video
    prediction beyond mean square error").
    Both inputs are NCHW tensors with values in [0, 1].
    '''
    # absolute finite differences (image gradients) along H and W, cropped so
    # all four tensors share the same spatial size
    dgt1 = torch.abs(torch.diff(real, dim=-2))[:, :, 1:, 1:-1]
    dgt2 = torch.abs(torch.diff(real, dim=-1))[:, :, 1:-1, 1:]
    dpred1 = torch.abs(torch.diff(fake, dim=-2))[:, :, 1:, 1:-1]
    dpred2 = torch.abs(torch.diff(fake, dim=-1))[:, :, 1:-1, 1:]
    # 10 * log10(max_intensity^2 / mean gradient difference), with max intensity 1
    return 10 * torch.log10(1.0**2 / torch.mean(torch.abs(dgt1 + dgt2 - dpred1 - dpred2))).cpu().item()
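# Illustrative usage of sd_func (hypothetical tensors, NCHW in [0, 1]):
#   real = torch.rand(1, 3, 256, 256)
#   fake = torch.rand(1, 3, 256, 256)
#   print(sd_func(real, fake))  # higher values mean better-preserved gradients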