import numpy as np
import torch,math
from PIL import Image
import torchvision
from easydict import EasyDict as edict


# Per-dataset scene constants: (origin_height, realworld_scale).
#   origin_height:   height at which the ground photo is taken, in real-world scale.
#   realworld_scale: real-world extent corresponding to the normalized [-1, 1] coordinate.
_SCENE_SCALE = {
    'CVACT_Shi': (2, 30),
    'CVUSA': (2, 55),
}


def _scene_scale(opt):
    """Return ``(origin_height, realworld_scale)`` for ``opt.data.dataset``.

    Raises:
        NotImplementedError: if the dataset has no registered scene scale.
    """
    # Previously ``assert Exception(...)`` was used here; it never fires (an
    # exception instance is truthy) and the caller then hit UnboundLocalError.
    try:
        return _SCENE_SCALE[opt.data.dataset]
    except KeyError:
        raise NotImplementedError(
            'Not implement yet: dataset %r' % (opt.data.dataset,)) from None


def position_produce(opt):
    """Build a positional-encoding volume over the (depth, x, y) voxel grid.

    The depth axis has ``opt.arch.gen.depth_arch.output_nc`` entries, plus one
    when ``opt.optim.ground_prior`` is set (``render`` prepends a ground layer
    to the voxel in that case). Each axis index is divided by the axis length.

    Returns:
        Tensor ``[2 * 3 * PE_channel, D, sat_size[1], sat_size[0]]`` on ``opt.device``.
    """
    depth_channel = opt.arch.gen.depth_arch.output_nc
    if opt.optim.ground_prior:
        depth_channel = depth_channel + 1
    z_ = torch.arange(depth_channel) / depth_channel
    x_ = torch.arange(opt.data.sat_size[1]) / opt.data.sat_size[1]
    y_ = torch.arange(opt.data.sat_size[0]) / opt.data.sat_size[0]
    Z, X, Y = torch.meshgrid(z_, x_, y_)
    coords = torch.stack((Z, X, Y), dim=-1).to(opt.device)  # [D, sat_size[1], sat_size[0], 3]
    pos = positional_encoding(opt, coords)
    # channels first
    return pos.permute(3, 0, 1, 2)


def positional_encoding(opt, input):  # [B,...,N]
    """NeRF-style sin/cos encoding of the last axis of ``input``.

    Uses ``opt.arch.gen.PE_channel`` frequencies ``2**k * pi``.

    Returns:
        Tensor ``[..., 2 * N * PE_channel]``.
    """
    shape = input.shape
    freq = 2 ** torch.arange(opt.arch.gen.PE_channel, dtype=torch.float32, device=opt.device) * np.pi  # [L]
    spectrum = input[..., None] * freq  # [B,...,N,L]
    sin, cos = spectrum.sin(), spectrum.cos()  # [B,...,N,L]
    input_enc = torch.stack([sin, cos], dim=-2)  # [B,...,N,2,L]
    input_enc = input_enc.view(*shape[:-1], -1)  # [B,...,2NL]
    return input_enc


def get_original_coord(opt):
    """Return the unit ray direction of every panorama pixel.

    pano_direction [X,Y,Z]: x right, y up, z out.
    Rows map to latitude ``(1 - 2*row/H) * pi/2`` for CVACT variants and
    ``(1 - 2*row/H) * pi/4`` for CVUSA; columns map to longitude
    ``pi * (-0.5 - 2*col/W)``.

    Returns:
        ``np.ndarray`` ``[H, W, 3]`` with ``W, H = opt.data.pano_size``.

    Raises:
        NotImplementedError: for a dataset with no latitude mapping.
    """
    W, H = opt.data.pano_size
    # column / row index of every pixel, both [H, W]
    col = np.repeat(np.arange(W).reshape(1, W), H, axis=0)
    row = np.repeat(np.arange(H).reshape(1, H), W, axis=0).T

    if opt.data.dataset in ['CVACT_Shi', 'CVACT', 'CVACThalf']:
        theta = (1 - 2 * row / H) * np.pi / 2  # latitude
    elif opt.data.dataset in ['CVUSA']:
        theta = (1 - 2 * row / H) * np.pi / 4
    else:
        raise NotImplementedError(
            'Not implement yet: no panorama latitude mapping for dataset %r' % (opt.data.dataset,))
    phi = math.pi * (-0.5 - 2 * col / W)  # longitude

    return np.stack((np.cos(theta) * np.cos(phi),
                     np.sin(theta),
                     -np.cos(theta) * np.sin(phi)), axis=-1)


def render(opt, feature, voxel, pano_direction, PE=None):
    """Render ground panoramas from satellite features by volume rendering.

    Args:
        opt: config. Reads ``opt.data.{sat_size, pano_size, dataset,
            sample_total_length, sample_number, max_height}``,
            ``opt.origin_H_W``, ``opt.optim.ground_prior`` and ``opt.device``.
        feature: B,C,H_sat,W_sat feature or an input RGB satellite image.
        voxel: B,N,H_sat,W_sat density of each grid cell.
        pano_direction: per-pixel unit ray directions, e.g.
            ``get_original_coord(opt)`` with a leading batch axis of 1.
        PE: optional positional-encoding volume; when given it is sampled
            along the rays and concatenated to the colour channels.

    Returns:
        dict with ``rgb`` [B,C,H_pano,W_pano], ``opacity`` [B,1,H_pano,W_pano],
        ``depth`` [B,1,H_pano,W_pano] and the (possibly ground-prior-augmented)
        ``voxel``.

    Raises:
        NotImplementedError: for a dataset with no registered scene scale.
    """
    sat_W, sat_H = opt.data.sat_size
    BS = feature.size(0)
    S = opt.data.sample_number
    origin_height, realworld_scale = _scene_scale(opt)

    assert sat_W == sat_H
    pixel_resolution = realworld_scale / sat_W  # real-world size of one satellite pixel

    if opt.data.sample_total_length:
        sample_total_length = opt.data.sample_total_length
    else:
        # Farthest ray length: half-diagonal of the scene combined with a
        # vertical offset of 2 or of (max_height - origin_height), whichever is
        # longer; converted to pixels, then to the normalized coordinate.
        half = realworld_scale / 2
        farthest = max(np.sqrt(half ** 2 + half ** 2 + 2 ** 2),
                       np.sqrt(half ** 2 + half ** 2 + (opt.data.max_height - origin_height) ** 2))
        sample_total_length = int(farthest / pixel_resolution) / (sat_W / 2)

    # -1 is the lowest position in the normalized coordinate; lift the camera by origin_height.
    origin_z = torch.ones([BS, 1]) * (-1 + (origin_height / (realworld_scale / 2)))
    if opt.origin_H_W is None:  # origin_H_W: camera position in the normalized coordinate
        origin_H, origin_w = torch.zeros([BS, 1]), torch.zeros([BS, 1])
    else:
        origin_H = torch.ones([BS, 1]) * opt.origin_H_W[0]
        origin_w = torch.ones([BS, 1]) * opt.origin_H_W[1]
    # (w, z, h) ordering, similar to the NeRF coordinate definition
    origin = torch.cat([origin_w, origin_z, origin_H], dim=1).to(opt.device)[:, None, None, :]
    # Fixed step along every ray: sample i sits at (i + 1) * sample_total_length / S.
    sample_len = ((torch.arange(S) + 1) * (sample_total_length / S)).to(opt.device)
    origin = origin[..., None]
    pano_direction = pano_direction[..., None]  # directions are already unit length
    depth = sample_len[None, None, None, None, :]
    sample_point = origin + pano_direction * depth  # [BS, H_pano, W_pano, 3, S]

    if opt.optim.ground_prior:
        # prepend a dense (density 1000) layer as the lowest voxel slice
        ground = torch.ones(voxel.size(0), 1, voxel.size(2), voxel.size(3), device=opt.device) * 1000
        voxel = torch.cat([ground, voxel], 1)

    N = voxel.size(1)
    voxel_low = -1
    voxel_max = -1 + opt.data.max_height / (realworld_scale / 2)  # voxel top in normalized space
    # grid_sample expects (x, y, z) = (w, h, depth): [BS, S, H_pano, W_pano, 3]
    grid = sample_point.permute(0, 4, 1, 2, 3)[..., [0, 2, 1]]
    # rescale z from [voxel_low, voxel_max] to grid_sample's [-1, 1]
    grid[..., 2] = ((grid[..., 2] - voxel_low) / (voxel_max - voxel_low)) * 2 - 1
    grid = grid.float()

    color_input = feature.unsqueeze(2).repeat(1, 1, N, 1, 1)
    alpha_grid = torch.nn.functional.grid_sample(voxel.unsqueeze(1), grid)
    color_grid = torch.nn.functional.grid_sample(color_input, grid)
    if PE is not None:
        PE_grid = torch.nn.functional.grid_sample(PE[None, ...], grid[:1, ...])
        color_grid = torch.cat([color_grid, PE_grid.repeat(BS, 1, 1, 1, 1)], dim=1)

    depth_sample = depth.permute(0, 1, 2, 4, 3).view(1, -1, S, 1)
    feature_size = color_grid.size(1)
    color_grid = color_grid.permute(0, 3, 4, 2, 1).view(BS, -1, S, feature_size)
    alpha_grid = alpha_grid.permute(0, 3, 4, 2, 1).view(BS, -1, S)
    intv = sample_total_length / S
    output = composite(opt, rgb_samples=color_grid, density_samples=alpha_grid,
                       depth_samples=depth_sample, intv=intv)
    output['voxel'] = voxel
    return output


def _ray_weights(density_samples, intv):
    """Return per-sample compositing weights ``T_i * alpha_i`` as [B,HW,N,1].

    ``alpha_i = 1 - exp(-sigma_i * intv)`` and
    ``T_i = exp(-sum_{j<i} sigma_j * intv)`` (exclusive cumulative sum, so the
    first sample sees full transmittance).
    """
    sigma_delta = density_samples * intv  # [B,HW,N]
    alpha = 1 - (-sigma_delta).exp_()  # [B,HW,N]
    shifted = torch.cat([torch.zeros_like(sigma_delta[..., :1]), sigma_delta[..., :-1]], dim=2)
    T = (-shifted.cumsum(dim=2)).exp_()  # [B,HW,N]
    return (T * alpha)[..., None]  # [B,HW,N,1]


def composite(opt, rgb_samples, density_samples, depth_samples, intv):
    """Alpha-composite samples along panorama rays into 2-D ground images.

    Args:
        opt: config; ``opt.data.pano_size`` = (W, H) of the output.
        rgb_samples: [B,HW,N,C] features sampled from the satellite image.
        density_samples: [B,HW,N] densities sampled from the predicted voxel.
        depth_samples: [1,1,N,1] depth of each sample along the ray.
        intv: spacing between consecutive samples.

    Returns:
        dict with ``rgb`` [B,C,H,W], ``opacity`` [B,1,H,W], ``depth`` [B,1,H,W].
    """
    prob = _ray_weights(density_samples, intv)
    # integrate RGB and depth weighted by probability
    depth = (depth_samples * prob).sum(dim=2)  # [B,HW,1]
    rgb = (rgb_samples * prob).sum(dim=2)  # [B,HW,C]
    opacity = prob.sum(dim=2)  # [B,HW,1]
    W, H = opt.data.pano_size
    depth = depth.permute(0, 2, 1).view(depth.size(0), -1, H, W)
    rgb = rgb.permute(0, 2, 1).view(rgb.size(0), -1, H, W)
    opacity = opacity.view(opacity.size(0), 1, H, W)
    return {'rgb': rgb, 'opacity': opacity, 'depth': depth}


def get_sat_ori(opt):
    """Return satellite pixel-centre ray origins as a [1,H,W,3] (x, 1, y) grid.

    x and y are pixel centres mapped into (-1, 1) along their own axis; the
    middle component (z) is fixed at 1, the top of the normalized range.
    """
    W, H = opt.data.sat_size
    y_range = (torch.arange(H, dtype=torch.float32) + 0.5) / (0.5 * H) - 1
    # normalize x by W (was H): otherwise non-square images leave [-1, 1]
    x_range = (torch.arange(W, dtype=torch.float32) + 0.5) / (0.5 * W) - 1
    Y, X = torch.meshgrid(y_range, x_range)
    Z = torch.ones_like(Y)
    return torch.stack([X, Z, Y], dim=-1)[None, :, :]


def render_sat(opt, voxel):
    """Render the voxel from straight above into satellite-view opacity/depth.

    Rays start at every satellite pixel centre at z = 1 (see ``get_sat_ori``)
    and march straight down over a normalized length of 2.

    Args:
        opt: config; reads ``opt.data.{sat_size, dataset, sample_number,
            max_height}`` and ``opt.device``.
        voxel: B,N,H_sat,W_sat density volume (already processed).

    Returns:
        ``(opacity, depth)``, each [B,1,sat_size[1],sat_size[0]].

    Raises:
        NotImplementedError: for a dataset with no registered scene scale.
    """
    BS = voxel.size(0)
    S = opt.data.sample_number
    _, realworld_scale = _scene_scale(opt)
    sample_total_length = 2  # spans the full normalized height, z = 1 down to z = -1

    origin = get_sat_ori(opt).to(opt.device)[..., None]  # [1, H, W, 3, 1]
    direction = torch.tensor([0, -1, 0])[None, None, None, :, None].to(opt.device)
    sample_len = ((torch.arange(S) + 1) * (sample_total_length / S)).to(opt.device)
    depth = sample_len[None, None, None, None, :]
    sample_point = origin + direction * depth  # [1, H, W, 3, S]

    voxel_low = -1
    voxel_max = -1 + opt.data.max_height / (realworld_scale / 2)  # voxel top in normalized space
    grid = sample_point.permute(0, 4, 1, 2, 3)[..., [0, 2, 1]]  # [1, S, H, W, 3]
    grid[..., 2] = ((grid[..., 2] - voxel_low) / (voxel_max - voxel_low)) * 2 - 1
    # The rays are shared by the whole batch, but grid_sample requires the grid
    # batch size to match the input's (previously this failed for B > 1).
    grid = grid.float().expand(BS, -1, -1, -1, -1)
    alpha_grid = torch.nn.functional.grid_sample(voxel.unsqueeze(1), grid)

    depth_sample = depth.permute(0, 1, 2, 4, 3).view(1, -1, S, 1)
    # reshape with the voxel's batch size instead of opt.batch_size
    alpha_grid = alpha_grid.permute(0, 3, 4, 2, 1).view(BS, -1, S)
    intv = sample_total_length / S
    output = composite_sat(opt, density_samples=alpha_grid, depth_samples=depth_sample, intv=intv)
    return output['opacity'], output['depth']


def composite_sat(opt, density_samples, depth_samples, intv):
    """Alpha-composite density samples along vertical rays into satellite maps.

    Args:
        opt: config; ``opt.data.sat_size`` = (W, H) of the output.
        density_samples: [B,HW,N] densities along each ray.
        depth_samples: [1,1,N,1] depth of each sample along the ray.
        intv: spacing between consecutive samples.

    Returns:
        dict with ``opacity`` [B,1,H,W] and ``depth`` [B,1,H,W].
    """
    prob = _ray_weights(density_samples, intv)
    depth = (depth_samples * prob).sum(dim=2)  # [B,HW,1]
    opacity = prob.sum(dim=2)  # [B,HW,1]
    W, H = opt.data.sat_size
    depth = depth.permute(0, 2, 1).view(depth.size(0), -1, H, W)
    opacity = opacity.view(opacity.size(0), 1, H, W)
    return {'opacity': opacity, 'depth': depth}


if __name__ == '__main__':
    # test_demo: render a panorama from an all-zero voxel and save it next to
    # the real panorama, plus the opacity map.
    import time

    opt = edict()
    opt.device = 'cuda'
    opt.origin_H_W = None  # read by render(); None puts the camera at (0, 0)
    opt.data = edict()
    opt.data.pano_size = [512, 256]
    opt.data.sat_size = [256, 256]
    opt.data.dataset = 'CVACT_Shi'
    opt.data.max_height = 20
    opt.data.sample_number = 300
    opt.data.sample_total_length = 1
    opt.arch = edict()
    opt.arch.gen = edict()  # was missing: the next assignments need it
    opt.optim = edict()
    opt.optim.ground_prior = False
    opt.arch.gen.transform_mode = 'volum_rendering'
    # opt.arch.gen.transform_mode = 'proj_like_radus'
    BS = 1

    sat_name = './CVACT/satview_correct/__-DFIFxvZBCn1873qkqXA_satView_polish.png'
    a = Image.open(sat_name)
    a = np.array(a).astype(np.float32)
    a = torch.from_numpy(a)
    a = a.permute(2, 0, 1).unsqueeze(0).to(opt.device).repeat(BS, 1, 1, 1) / 255.

    pano = sat_name.replace('satview_correct', 'streetview').replace('_satView_polish', '_grdView')
    pano = np.array(Image.open(pano).resize(tuple(opt.data.pano_size)))
    pano = torch.from_numpy(pano)
    pano = pano.permute(2, 0, 1).unsqueeze(0).to(opt.device).repeat(BS, 1, 1, 1) / 255.
    voxel = torch.zeros([BS, 65, 256, 256]).to(opt.device)
    pano_direction = torch.from_numpy(get_original_coord(opt)).unsqueeze(0).to(opt.device)

    star = time.time()
    with torch.autograd.profiler.profile(enabled=True, use_cuda=True, record_shapes=False, profile_memory=False) as prof:
        output = render(opt, a, voxel, pano_direction)
    print(prof.table())

    print(time.time() - star)

    # render() returns a dict; unpacking it directly into two names fails.
    rgb, opacity = output['rgb'], output['opacity']
    torchvision.utils.save_image(torch.cat([rgb, pano], 2), opt.arch.gen.transform_mode + '.png')
    print(opt.arch.gen.transform_mode + '.png')
    torchvision.utils.save_image(opacity, 'opa.png')