# Uploaded by baglada.
# Duplicated from Epoching/3D_Photo_Inpainting (commit d8816ee).
import os
import numpy as np
try:
import cynetworkx as netx
except ImportError:
import networkx as netx
import matplotlib.pyplot as plt
from functools import partial
from vispy import scene, io
from vispy.scene import visuals
from vispy.visuals.filters import Alpha
import cv2
from moviepy.editor import ImageSequenceClip
from skimage.transform import resize
import time
import copy
import torch
import os
from utils import path_planning, open_small_mask, clean_far_edge, refine_depth_around_edge
from utils import refine_color_around_edge, filter_irrelevant_edge_new, require_depth_edge, clean_far_edge_new
from utils import create_placeholder, refresh_node, find_largest_rect
from mesh_tools import get_depth_from_maps, get_map_from_ccs, get_edge_from_nodes, get_depth_from_nodes, get_rgb_from_nodes, crop_maps_by_size, convert2tensor, recursive_add_edge, update_info, filter_edge, relabel_node, depth_inpainting
from mesh_tools import refresh_bord_depth, enlarge_border, fill_dummy_bord, extrapolate, fill_missing_node, incomplete_node, get_valid_size, dilate_valid_size, size_operation
import transforms3d
import random
from functools import reduce
def create_mesh(depth, image, int_mtx, config):
    """Build the initial layered-depth-image (LDI) graph from a depth map.

    Every pixel becomes one node keyed ``(row, col, -depth[row, col])``, with
    row/col shifted by ``config['extrapolation_thickness']``. Each node is
    joined to its bottom and right neighbour, giving a 4-connected grid.

    Args:
        depth: 2-D array indexed ``depth[row, col]``.
        image: 3-D array ``(H, W, C)`` indexed ``image[row, col]``.
        int_mtx: 3x3 intrinsic matrix. Row 0 is scaled by W and row 1 by H to
            get pixel intrinsics, so presumably it holds normalised
            intrinsics (confirm with caller).
        config: dict providing ``'extrapolation_thickness'``.

    Returns:
        ``(LDI, xy2depth, image, depth)``:
        - the graph;
        - a dict mapping padded ``(row, col)`` to ``[-depth]``;
        - image and depth zero-padded by ``extrapolation_thickness`` on the
          two spatial axes.
    """
    H, W, _ = image.shape
    thickness = config['extrapolation_thickness']
    ext_H, ext_W = H + 2 * thickness, W + 2 * thickness
    LDI = netx.Graph(H=ext_H, W=ext_W, noext_H=H, noext_W=W, cam_param=int_mtx)
    int_mtx_pix = int_mtx * np.array([[W], [H], [1.]])
    LDI.graph['cam_param_pix'], LDI.graph['cam_param_pix_inv'] = int_mtx_pix, np.linalg.inv(int_mtx_pix)
    LDI.graph['hoffset'], LDI.graph['woffset'] = thickness, thickness
    LDI.graph['bord_up'], LDI.graph['bord_down'] = thickness, thickness + H
    LDI.graph['bord_left'], LDI.graph['bord_right'] = thickness, thickness + W

    disp = 1. / (-depth)
    xy2depth = {}
    for idx in range(H):
        for idy in range(W):
            x, y = idx + thickness, idy + thickness
            z = -depth[idx, idy]
            LDI.add_node((x, y, z),
                         color=image[idx, idy],
                         disp=disp[idx, idy],
                         synthesis=False,
                         cc_id=set())
            xy2depth[(x, y)] = [z]

    # Join each pixel to its bottom and right neighbour. A plain loop is used
    # (not a list comprehension) because add_edge is called for its side effect.
    for x, y, d in LDI.nodes:
        for ne in ((x + 1, y), (x, y + 1)):
            if ne[0] < LDI.graph['bord_down'] and ne[1] < LDI.graph['bord_right']:
                LDI.add_edge((ne[0], ne[1], xy2depth[ne][0]), (x, y, d))
    LDI = calculate_fov(LDI)

    pad_hw = ((thickness, thickness), (thickness, thickness))
    image = np.pad(image, pad_width=pad_hw + ((0, 0),), mode='constant')
    depth = np.pad(depth, pad_width=pad_hw, mode='constant')
    return LDI, xy2depth, image, depth
def tear_edges(mesh, threshold = 0.00025, xy2depth=None):
    """Remove mesh edges that cross a disparity discontinuity.

    1. Every edge whose endpoints' ``'disp'`` values differ by more than
       *threshold* is removed. The endpoint with the smaller ``abs(z)`` is
       "near" and the other is "far". The far node's ``'near'`` list receives
       the near node, and the near node's ``'far'`` list receives the far node.
    2. A second pass removes "dangling" edges. For an edge (x, y)-(x, y+1),
       this means the parallel edges in rows x-1 and x+1 were both torn in
       step 1 but this one was not. The same test is applied per column for
       edges (x, y)-(x+1, y).

    Args:
        mesh: LDI graph. Nodes are ``(x, y, z)`` carrying ``'disp'``.
            ``mesh.graph`` provides 'H', 'W' and the 'bord_*' keys.
        threshold: disparity difference above which an edge is torn.
        xy2depth: dict ``(x, y) -> [z, ...]`` used to rebuild node keys in the
            dangling-edge pass. It must be given whenever that pass finds a
            candidate.

    Returns:
        The same *mesh*, modified in place.
    """
    remove_edge_list = []
    remove_horizon, remove_vertical = np.zeros((2, mesh.graph['H'], mesh.graph['W']))
    mesh_nodes = mesh.nodes
    for edge in mesh.edges:
        if abs(mesh_nodes[edge[0]]['disp'] - mesh_nodes[edge[1]]['disp']) > threshold:
            remove_edge_list.append((edge[0], edge[1]))
            near, far = edge if abs(edge[0][2]) < abs(edge[1][2]) else edge[::-1]
            # Record who occludes whom. The former one-liner
            # ``x = [] if x is None else x.append(...)`` assigned append()'s
            # return value (None), so these lists were never filled.
            if mesh_nodes[far].get('near') is None:
                mesh_nodes[far]['near'] = []
            mesh_nodes[far]['near'].append(near)
            if mesh_nodes[near].get('far') is None:
                mesh_nodes[near]['far'] = []
            mesh_nodes[near]['far'].append(far)
            # Mark the torn edge at its lower-index endpoint.
            if near[0] == far[0]:
                remove_horizon[near[0], min(near[1], far[1])] = 1
            elif near[1] == far[1]:
                remove_vertical[min(near[0], far[0]), near[1]] = 1
    mesh.remove_edges_from(remove_edge_list)

    # Dangling: both parallel neighbours torn (roll sum == 2), this edge not (own flag 0).
    dang_horizon = np.where(np.roll(remove_horizon, 1, 0) + np.roll(remove_horizon, -1, 0) - remove_horizon == 2)
    dang_vertical = np.where(np.roll(remove_vertical, 1, 1) + np.roll(remove_vertical, -1, 1) - remove_vertical == 2)

    graph = mesh.graph
    prjto3d = lambda x, y: (x, y, xy2depth[(x, y)][0])
    node_existence = lambda x, y: mesh.has_node(prjto3d(x, y))

    remove_edge_list = []
    # np.roll wraps around, so candidates next to the border are skipped.
    for x, y in zip(*dang_horizon):
        if graph['bord_up'] + 1 <= x < graph['bord_down'] - 1 and \
                node_existence(x, y) and node_existence(x, y + 1):
            remove_edge_list.append((prjto3d(x, y), prjto3d(x, y + 1)))
    for x, y in zip(*dang_vertical):
        if graph['bord_left'] + 1 <= y < graph['bord_right'] - 1 and \
                node_existence(x, y) and node_existence(x + 1, y):
            remove_edge_list.append((prjto3d(x, y), prjto3d(x + 1, y)))
    mesh.remove_edges_from(remove_edge_list)
    return mesh
def calculate_fov(mesh):
    """Store field-of-view angles and aspect ratio in ``mesh.graph``.

    The angles come from ``mesh.graph['cam_param']`` as
    ``2 * arctan(1 / (2 * f))`` with ``f`` the diagonal entries [0, 0] and
    [1, 1]. The aspect ratio is ``noext_H / noext_W``. Returns *mesh*.
    """
    g = mesh.graph
    intrinsics = g['cam_param']
    g['hFov'] = 2 * np.arctan(1. / (2 * intrinsics[0, 0]))
    g['vFov'] = 2 * np.arctan(1. / (2 * intrinsics[1, 1]))
    g['aspect'] = g['noext_H'] / g['noext_W']
    return mesh
def calculate_fov_FB(mesh):
    """Set FOV from a fixed short-side angle of 0.508015513 rad.

    The image's shorter side gets the fixed angle. The longer side's angle is
    derived through the aspect ratio ``H / W`` so that the tangents of the
    half-angles keep that ratio. Results go into ``mesh.graph``
    ('aspect', 'hFov', 'vFov'). Returns *mesh*.
    """
    g = mesh.graph
    g['aspect'] = g['H'] / g['W']
    short_fov = 0.508015513
    if g['H'] <= g['W']:
        g['vFov'] = short_fov
        g['hFov'] = 2.0 * np.arctan(np.tan(short_fov / 2.0) / g['aspect'])
        return mesh
    g['hFov'] = short_fov
    g['vFov'] = 2.0 * np.arctan(np.tan(short_fov / 2.0) * g['aspect'])
    return mesh
def reproject_3d_int_detail(sx, sy, z, k_00, k_02, k_11, k_12, w_offset, h_offset):
    """Back-project pixel (sx, sy) at depth z to a 3-D point.

    The pixel centre (``+0.5``, minus the offsets) is mapped through the
    scalar entries of an inverse intrinsic matrix. The result is scaled by
    ``abs(z)`` and returned as ``[X, Y, |z|]``, where column ``sy`` drives X
    and row ``sx`` drives Y.
    """
    scale = abs(z)
    ray_x = (sy + 0.5 - w_offset) * k_00 + k_02
    ray_y = (sx + 0.5 - h_offset) * k_11 + k_12
    return [scale * ray_x, scale * ray_y, scale]
def reproject_3d_int_detail_FB(sx, sy, z, w_offset, h_offset, mesh):
    """Back-project pixel (sx, sy) at depth z using FOV-based intrinsics.

    The pixel centre is mapped to [-1, 1] normalised coordinates (x grows
    with column ``sy``, y grows upward as row ``sx`` decreases). These are
    scaled by the tangents of the half FOV angles, and the ray
    ``[x, y, -1]`` is scaled by ``|z|``. The tangents are cached in
    ``mesh.graph`` as 'tan_hFov' / 'tan_vFov' on first use.
    """
    g = mesh.graph
    if g.get('tan_hFov') is None:
        g['tan_hFov'] = np.tan(g['hFov'] / 2.)
    if g.get('tan_vFov') is None:
        g['tan_vFov'] = np.tan(g['vFov'] / 2.)
    ndc_x = -1. + 2. * ((sy + 0.5 - w_offset) / (g['W'] - 1))
    ndc_y = 1. - 2. * (sx + 0.5 - h_offset) / (g['H'] - 1)
    direction = np.array([ndc_x * g['tan_hFov'], ndc_y * g['tan_vFov'], -1])
    return direction * np.abs(z)
def reproject_3d_int(sx, sy, z, mesh):
    """Back-project integer pixel (sx, sy) at depth z via the stored intrinsics.

    ``mesh.graph['cam_param_pix_inv']`` is used as the inverse intrinsic
    matrix. If its [0, 2] entry is positive, it is treated as the
    non-inverted matrix and inverted first. The homogeneous pixel
    ``[sy - woffset, sx - hoffset, 1]`` is mapped through it and scaled by
    ``|z|``. Returns a flat length-3 array.
    """
    inv_k = mesh.graph['cam_param_pix_inv'].copy()
    if not inv_k[0, 2] <= 0:
        inv_k = np.linalg.inv(inv_k)
    pixel = np.array([sy - mesh.graph['woffset'], sx - mesh.graph['hoffset'], 1]).reshape(3, 1)
    direction = np.dot(inv_k, pixel)
    return (direction * np.abs(z)).flatten()
def generate_init_node(mesh, config, min_node_in_cc):
    """Index large connected components by pixel and delete the small ones.

    Connected components are visited largest first.
    - Components with at least ``min_node_in_cc`` nodes: each node is
      recorded in the returned ``info_on_pix`` as
      ``{(x, y): [{'depth', 'color', 'synthesis': False, 'disp'}]}``.
    - Smaller components are deleted from *mesh*. Before deletion, each
      doomed node is removed from the ``'near'`` list of every node in its
      ``'far'`` list, and from the ``'far'`` list of every node in its
      ``'near'`` list.

    Args:
        mesh: LDI graph with ``(x, y, z)`` nodes carrying 'color' and 'disp'.
        config: unused.
        min_node_in_cc: minimum component size to keep.

    Returns:
        ``(mesh, info_on_pix)``; *mesh* is modified in place.
    """
    node_attrs = mesh.nodes
    pixel_info = {}
    doomed = []
    for component in sorted(netx.connected_components(mesh), key=len, reverse=True):
        if len(component) < min_node_in_cc:
            doomed.extend((px, py, pz) for (px, py, pz) in component)
            continue
        for (px, py, pz) in component:
            attrs = node_attrs[(px, py, pz)]
            pixel_info[(px, py)] = [{'depth': pz,
                                     'color': attrs['color'],
                                     'synthesis': False,
                                     'disp': attrs['disp']}]

    def _unlink(node, own_key, partner_key):
        # Drop *node* from the partner_key list of every node in its own_key list.
        partners = node_attrs[node].get(own_key)
        for partner in ([] if partners is None else partners):
            if not mesh.has_node(partner):
                continue
            back_refs = node_attrs[partner].get(partner_key)
            if back_refs is not None and node in back_refs:
                back_refs.remove(node)

    for node in doomed:
        _unlink(node, 'far', 'near')
        _unlink(node, 'near', 'far')
    for node in doomed:
        mesh.remove_node(node)
    return mesh, pixel_info
def get_neighbors(mesh, node):
    """Return the neighbours of *node* in *mesh* as a new list."""
    neighbours = list(mesh.neighbors(node))
    return neighbours
def generate_face(mesh, info_on_pix, config):
    """Triangulate the mesh around every node.

    Each node's mesh neighbours are bucketed into up/down/left/right by
    comparing their pixel coordinates with the node's own. The old code
    compared a neighbour against itself (``ne_node[1] == ne_node[1] - 1``),
    which is never true, so left/up neighbours were filed as right/down and
    the triangles came out with mixed winding. One triangle
    ``(b, self, a)`` is then emitted for every neighbour pair in the
    quadrants up-right, right-down, down-left and left-up, visited in that
    rotational order.

    Args:
        mesh: graph of ``(x, y, z)`` nodes carrying a 'cur_id' attribute
            (the vertex index written into each face).
        info_on_pix: unused; kept for interface compatibility.
        config: dict. If ``config.get('save_ply') is True``, faces are PLY
            lines ``"3 b self a\\n"`` (so cur_id presumably holds strings in
            that mode; confirm). Otherwise faces are ``[b, self, a]`` lists.

    Returns:
        List of faces in the selected format.
    """
    ply_flag = config.get('save_ply')

    def out_fmt(faces, cur_id_b, cur_id_self, cur_id_a, ply_flag):
        if ply_flag is True:
            faces.append(' '.join(['3', cur_id_b, cur_id_self, cur_id_a]) + '\n')
        else:
            faces.append([cur_id_b, cur_id_self, cur_id_a])

    str_faces = []
    mesh_nodes = mesh.nodes
    for node in mesh_nodes:
        cur_id_self = mesh_nodes[node]['cur_id']
        four_dir_nes = {'up': [], 'left': [], 'down': [], 'right': []}
        for ne_node in get_neighbors(mesh, node):
            store_tuple = [ne_node, mesh_nodes[ne_node]['cur_id']]
            if ne_node[0] == node[0]:
                side = 'left' if ne_node[1] == node[1] - 1 else 'right'
            else:
                side = 'up' if ne_node[0] == node[0] - 1 else 'down'
            four_dir_nes[side].append(store_tuple)
        for side_a, side_b in (('up', 'right'), ('right', 'down'),
                               ('down', 'left'), ('left', 'up')):
            for _, cur_id_a in four_dir_nes[side_a]:
                for _, cur_id_b in four_dir_nes[side_b]:
                    out_fmt(str_faces, cur_id_b, cur_id_self, cur_id_a, ply_flag)
    return str_faces
def reassign_floating_island(mesh, info_on_pix, image, depth):
    """Give a depth to "floating island" pixels that have no mesh node.

    A pixel inside the border (``bord_up..bord_down`` x
    ``bord_left..bord_right``) with no entry in *info_on_pix* is "lost".
    Lost pixels are grouped into 4-connected islands. For each island:

    1. Collect the neighbouring nodes that carry an ``'edge_id'`` attribute,
       grouped by edge id.
    2. Choose the edge id with the most such nodes (ties go to the first one
       found).
    3. Flood the island from that edge. Each lost pixel takes the mean depth
       of its already-assigned 4-neighbours. A node is added at that depth
       and connected to those neighbours.

    Islands with no neighbouring depth-edge node are left untouched. A flood
    sweep that fills no pixel stops the island; the previous unconditional
    ``while`` could spin forever in that case.

    Args:
        mesh: LDI graph; ``mesh.graph`` provides 'H', 'W' and 'bord_*'.
        info_on_pix: dict ``(x, y) -> [info dict]``; updated in place.
        image: colour image indexed ``image[x, y]``.
        depth: depth map indexed ``depth[x, y]``; updated in place with the
            negated reassigned depth.

    Returns:
        ``(mesh, info_on_pix, depth)``.
    """
    H, W = mesh.graph['H'], mesh.graph['W']
    mesh_nodes = mesh.nodes
    bord_up, bord_down = mesh.graph['bord_up'], mesh.graph['bord_down']
    bord_left, bord_right = mesh.graph['bord_left'], mesh.graph['bord_right']

    # is_inside(x, y, xmin, xmax, ymin, ymax): is pixel (x, y) inside the border?
    is_inside = lambda x, y, xmin, xmax, ymin, ymax: xmin <= x < xmax and ymin <= y < ymax
    # get_cross_nes(x, y): the four 4-connected neighbours of pixel (x, y).
    get_cross_nes = lambda x, y: [(x + 1, y), (x - 1, y), (x, y - 1), (x, y + 1)]

    # (A) Mark in-border pixels that have no node.
    lost_map = np.zeros((H, W))
    lost_map[bord_up:bord_down, bord_left:bord_right] = 1
    for (x, y) in info_on_pix:
        if is_inside(x, y, bord_up, bord_down, bord_left, bord_right):
            lost_map[x, y] = 0

    # (B) Number the islands with 4-connected component labelling.
    _, label_lost_map = cv2.connectedComponents(lost_map.astype(np.uint8), connectivity=4)
    mask = np.zeros((H, W))
    mask[bord_up:bord_down, bord_left:bord_right] = 1
    # Builtin int: the np.int alias was removed in NumPy 1.24.
    label_lost_map = (label_lost_map * mask).astype(int)

    # (C) Handle each island.
    for i in range(1, label_lost_map.max() + 1):
        lost_xs, lost_ys = np.where(label_lost_map == i)
        # (1) Neighbouring depth-edge nodes, grouped by edge id.
        surr_edge_ids = {}
        for lost_x, lost_y in zip(lost_xs, lost_ys):
            for ne in get_cross_nes(lost_x, lost_y):
                if ne not in info_on_pix:
                    continue
                for info in info_on_pix[ne]:
                    ne_node = (ne[0], ne[1], info['depth'])
                    if 'edge_id' in mesh_nodes[ne_node]:
                        surr_edge_ids.setdefault(mesh_nodes[ne_node]['edge_id'], []).append(ne_node)
        if not surr_edge_ids:
            continue
        # (2) The edge with the most neighbouring nodes.
        _, edge_nodes = max(surr_edge_ids.items(), key=lambda item: len(item[1]))
        edge_depth_map = np.zeros((H, W))
        for node in edge_nodes:
            edge_depth_map[node[0], node[1]] = node[2]

        # (3) Propagate depth inwards, one sweep at a time.
        while True:
            lost_xs, lost_ys = np.where(label_lost_map == i)
            if lost_xs.shape[0] == 0:
                break
            filled_any = False
            for lost_x, lost_y in zip(lost_xs, lost_ys):
                propagated_depth = []
                real_nes = []
                for ne in get_cross_nes(lost_x, lost_y):
                    if not is_inside(ne[0], ne[1], bord_up, bord_down, bord_left, bord_right) or \
                            edge_depth_map[ne[0], ne[1]] == 0:
                        continue
                    propagated_depth.append(edge_depth_map[ne[0], ne[1]])
                    real_nes.append(ne)
                if not real_nes:
                    continue
                filled_any = True
                reassign_depth = np.mean(propagated_depth)
                label_lost_map[lost_x, lost_y] = 0
                # Later pixels in this same sweep can already use this value.
                edge_depth_map[lost_x, lost_y] = reassign_depth
                depth[lost_x, lost_y] = -reassign_depth
                mesh.add_node((lost_x, lost_y, reassign_depth), color=image[lost_x, lost_y],
                              synthesis=False,
                              disp=1./reassign_depth,
                              cc_id=set())
                info_on_pix[(lost_x, lost_y)] = [{'depth': reassign_depth,
                                                  'color': image[lost_x, lost_y],
                                                  'synthesis': False,
                                                  'disp': 1./reassign_depth}]
                new_connections = [((lost_x, lost_y, reassign_depth),
                                    (ne[0], ne[1], edge_depth_map[ne[0], ne[1]])) for ne in real_nes]
                mesh.add_edges_from(new_connections)
            if not filled_any:
                # Remaining pixels cannot be reached from the chosen edge.
                break
    return mesh, info_on_pix, depth
def remove_node_feat(mesh, *feats):
    """Set each attribute named in *feats* to None on every node of *mesh*.

    Attributes are overwritten with None, not deleted; they are created with
    None when absent. Returns *mesh*.
    """
    node_view = mesh.nodes
    cleared = dict.fromkeys(feats)
    for key in node_view:
        node_view[key].update(cleared)
    return mesh
def update_status(mesh, info_on_pix, depth=None):
    '''
    Recompute the per-node 'near'/'far' occlusion lists from the current mesh.

    First, 'edge_id', 'far' and 'near' are reset to None on every node that
    already has them. Then each node is visited, skipping those whose
    neighbour iterator reports exactly 4 neighbours. For every in-border
    cross-neighbour pixel that is present in info_on_pix and NOT connected to
    the node in the mesh, each layer stored for that pixel is appended to the
    node's 'near' list if the node has the larger abs(depth). Otherwise it is
    appended to the node's 'far' list.

    If depth is given, every pixel whose depth-map value differs from
    abs(info_on_pix[...][0]['depth']) gets 'disp' recomputed, and the depth
    map is updated in place.

    Returns:
        (mesh, depth, info_on_pix) when depth is not None, otherwise mesh.

    Helpers:
    (2) clear_node_feat(G, fts) : Reset to None every feature in fts that a node of graph G already has.
    (6) get_cross_nes(x, y) : Get the four cross neighbors of pixel(x, y).
    '''
    # key_exist(d, k): key k present in d with a non-None value.
    key_exist = lambda d, k: d.get(k) is not None
    is_inside = lambda x, y, xmin, xmax, ymin, ymax: xmin <= x < xmax and ymin <= y < ymax
    get_cross_nes = lambda x, y: [(x + 1, y), (x - 1, y), (x, y - 1), (x, y + 1)]
    # append_element(d, k, x): a NEW list d[k] + [x], or [x] if d[k] is missing/None.
    append_element = lambda d, k, x: d[k] + [x] if key_exist(d, k) else [x]
    def clear_node_feat(G, fts):
        le_nodes = G.nodes
        for k in le_nodes:
            v = le_nodes[k]
            for ft in fts:
                # Only features the node already has are reset (to None, not deleted).
                if ft in v:
                    v[ft] = None
    clear_node_feat(mesh, ['edge_id', 'far', 'near'])
    bord_up, bord_down = mesh.graph['bord_up'], mesh.graph['bord_down']
    bord_left, bord_right = mesh.graph['bord_left'], mesh.graph['bord_right']
    le_nodes = mesh.nodes
    for node_key in le_nodes:
        # Skip nodes with exactly 4 neighbours. NOTE(review): this relies on the
        # neighbour iterator's __length_hint__ (for networkx, a dict-key iterator
        # whose hint is the neighbour count); mesh.degree(node_key) would be clearer.
        if mesh.neighbors(node_key).__length_hint__() == 4:
            continue
        # Candidate pixels: in-border cross neighbours that have node info...
        four_nes = [xx for xx in get_cross_nes(node_key[0], node_key[1]) if
                    is_inside(xx[0], xx[1], bord_up, bord_down, bord_left, bord_right) and
                    xx in info_on_pix]
        # ...minus those already connected in the mesh. list.remove raises
        # ValueError if a mesh neighbour's pixel is not in four_nes (e.g. out of
        # border or absent from info_on_pix); presumably that cannot happen — verify.
        [four_nes.remove((ne_node[0], ne_node[1])) for ne_node in mesh.neighbors(node_key)]
        for ne in four_nes:
            for info in info_on_pix[ne]:
                assert mesh.has_node((ne[0], ne[1], info['depth'])), "No node_key"
                ind_node = le_nodes[node_key]
                # A neighbour layer with smaller abs(depth) than this node goes to 'near', otherwise to 'far'.
                if abs(node_key[2]) > abs(info['depth']):
                    ind_node['near'] = append_element(ind_node, 'near', (ne[0], ne[1], info['depth']))
                else:
                    ind_node['far'] = append_element(ind_node, 'far', (ne[0], ne[1], info['depth']))
    if depth is not None:
        # Sync the depth map (and 'disp') with the first layer of each pixel.
        for key, value in info_on_pix.items():
            if depth[key[0], key[1]] != abs(value[0]['depth']):
                value[0]['disp'] = 1. / value[0]['depth']
                depth[key[0], key[1]] = abs(value[0]['depth'])
        return mesh, depth, info_on_pix
    else:
        return mesh
def group_edges(LDI, config, image, remove_conflict_ordinal, spdb=False):
    '''
    Group the discontinuity pixels of the LDI graph into depth edges.

    A node with fewer than 4 neighbours in LDI is a discontinuity pixel. These
    pixels are linked in a separate graph "discont_graph", including some
    diagonal links (see the first block comment in the body). Each connected
    component of that graph is one depth edge. On return, every node of a
    depth edge has LDI.nodes[node]['edge_id'] set to that edge's index in the
    returned list.

    Args:
        LDI: mesh graph whose nodes are (x, y, z) tuples. LDI.graph must hold
            'bord_up', 'bord_down', 'bord_left' and 'bord_right'. The node
            attributes 'far', 'near' and 'must_connect' are read when present.
        config: dict; config['depth_threshold'] is the disparity threshold.
        image: not used by this function.
        remove_conflict_ordinal: if truthy, depth edges that mix foreground
            and background pixels are split (second block comment in the body).
            This may pop 'near'/'far' attributes from LDI nodes.
        spdb: if True, stop in pdb at two checkpoints (debugging aid).

    Returns:
        (discont_ccs, LDI, discont_graph):
        - a list of node sets, one per depth edge;
        - the LDI graph, modified in place;
        - the discontinuity graph.

    Local helpers:
    (1) add_new_node(G, node) : add "node" to graph "G" unless already present
    (2) add_new_edge(G, node_a, node_b) : add edge "node_a--node_b" to graph "G" unless already present
    (3) exceed_thre(x, y, thre) : Check if abs(x) exceeds abs(y) by more than "thre" (one-directional)
    (4) key_exist(d, k) : Check if key "k" exists in dictionary "d" with a non-None value
    (5) comm_opp_bg(G, x, y) : Check if node "x" and "y" in graph "G" treat the same opposite node as background
    (6) comm_opp_fg(G, x, y) : Check if node "x" and "y" in graph "G" treat the same opposite node as foreground
    '''
    add_new_node = lambda G, node: None if G.has_node(node) else G.add_node(node)
    add_new_edge = lambda G, node_a, node_b: None if G.has_edge(node_a, node_b) else G.add_edge(node_a, node_b)
    exceed_thre = lambda x, y, thre: (abs(x) - abs(y)) > thre
    key_exist = lambda d, k: d.get(k) is not None
    # True when x and y share at least one node in their 'far' lists.
    comm_opp_bg = lambda G, x, y: key_exist(G.nodes[x], 'far') and key_exist(G.nodes[y], 'far') and \
                                  not(set(G.nodes[x]['far']).isdisjoint(set(G.nodes[y]['far'])))
    # True when x and y share at least one node in their 'near' lists.
    comm_opp_fg = lambda G, x, y: key_exist(G.nodes[x], 'near') and key_exist(G.nodes[y], 'near') and \
                                  not(set(G.nodes[x]['near']).isdisjoint(set(G.nodes[y]['near'])))
    discont_graph = netx.Graph()
    '''
    (A) Skip the pixel at image boundary, we don't want to deal with them.
    (B) Identify discontinuity by the number of its neighbor(degree).
    If the degree < 4(up/right/buttom/left). We will go through following steps:
    (1) Add the discontinuity pixel "node" to graph "discont_graph".
    (2) Find "node"'s cross neighbor(up/right/buttom/left) "ne_node".
    - If the cross neighbor "ne_node" is a discontinuity pixel(degree("ne_node") < 4),
    (a) add it to graph "discont_graph" and build the connection between "ne_node" and "node".
    (b) label its cross neighbor as invalid pixels "inval_diag_candi" to avoid building
    connection between original discontinuity pixel "node" and "inval_diag_candi".
    - Otherwise, find "ne_node"'s cross neighbors, called diagonal candidate "diag_candi".
    - The "diag_candi" is diagonal to the original discontinuity pixel "node".
    - If "diag_candi" exists, go to step(3).
    (3) A diagonal candidate "diag_candi" will be :
    - added to the "discont_graph" if its degree < 4.
    - connected to the original discontinuity pixel "node" if it satisfied either
    one of following criterion:
    (a) the difference of disparity between "diag_candi" and "node" is smaller than default threshold.
    (b) the "diag_candi" and "node" face the same opposite pixel. (See. function "tear_edges")
    (c) Both of "diag_candi" and "node" must_connect to each other. (See. function "combine_end_node")
    (C) Aggregate each connected part in "discont_graph" into "discont_ccs" (A.K.A. depth edge).
    '''
    for node in LDI.nodes:
        # (A) Only pixels at least one pixel inside the border are considered.
        if not(LDI.graph['bord_up'] + 1 <= node[0] <= LDI.graph['bord_down'] - 2 and \
               LDI.graph['bord_left'] + 1 <= node[1] <= LDI.graph['bord_right'] - 2):
            continue
        neighbors = [*LDI.neighbors(node)]
        # (B) Fewer than 4 neighbours: some edge of this pixel was torn.
        if len(neighbors) < 4:
            add_new_node(discont_graph, node)
            diag_candi_anc, inval_diag_candi, discont_nes = set(), set(), set()
            for ne_node in neighbors:
                if len([*LDI.neighbors(ne_node)]) < 4:
                    # (2a) Neighbour is itself a discontinuity pixel: link directly.
                    add_new_node(discont_graph, ne_node)
                    add_new_edge(discont_graph, ne_node, node)
                    discont_nes.add(ne_node)
                else:
                    # Otherwise it is an "anchor" through which diagonals are searched.
                    diag_candi_anc.add(ne_node)
            # (2b) Neighbours of directly-linked pixels that lie within one pixel of
            # "node" must not also be linked diagonally.
            inval_diag_candi = set([inval_diagonal for ne_node in discont_nes for inval_diagonal in LDI.neighbors(ne_node) if \
                                    abs(inval_diagonal[0] - node[0]) < 2 and abs(inval_diagonal[1] - node[1]) < 2])
            for ne_node in diag_candi_anc:
                # Pixels perpendicular to the node->ne_node step, i.e. the node's
                # diagonal neighbours reachable via ne_node.
                # NOTE(review): if ne_node shares neither row nor column with node,
                # diagonal_xys keeps a stale value (or is unbound) — confirm impossible.
                if ne_node[0] == node[0]:
                    diagonal_xys = [[ne_node[0] + 1, ne_node[1]], [ne_node[0] - 1, ne_node[1]]]
                elif ne_node[1] == node[1]:
                    diagonal_xys = [[ne_node[0], ne_node[1] + 1], [ne_node[0], ne_node[1] - 1]]
                for diag_candi in LDI.neighbors(ne_node):
                    # (3) Only diagonal candidates that are discontinuity pixels.
                    if [diag_candi[0], diag_candi[1]] in diagonal_xys and LDI.degree(diag_candi) < 4:
                        if diag_candi not in inval_diag_candi:
                            # (3a) disparity test (1/z, one-directional via exceed_thre) or
                            # (3b) shared opposite pixels on both the far and near side.
                            if not exceed_thre(1./node[2], 1./diag_candi[2], config['depth_threshold']) or \
                               (comm_opp_bg(LDI, diag_candi, node) and comm_opp_fg(LDI, diag_candi, node)):
                                add_new_node(discont_graph, diag_candi)
                                add_new_edge(discont_graph, diag_candi, node)
                        # (3c) Mutual 'must_connect' links (set in combine_end_node) always connect.
                        if key_exist(LDI.nodes[diag_candi], 'must_connect') and node in LDI.nodes[diag_candi]['must_connect'] and \
                            key_exist(LDI.nodes[node], 'must_connect') and diag_candi in LDI.nodes[node]['must_connect']:
                            add_new_node(discont_graph, diag_candi)
                            add_new_edge(discont_graph, diag_candi, node)
    # Debug hook.
    if spdb == True:
        import pdb; pdb.set_trace()
    # (C) Each connected component of discont_graph is one depth edge.
    discont_ccs = [*netx.connected_components(discont_graph)]
    '''
    In some corner case, a depth edge "discont_cc" will contain both
    foreground(FG) and background(BG) pixels. This violate the assumption that
    a depth edge can only composite by one type of pixel(FG or BG).
    We need to further divide this depth edge into several sub-part so that the
    assumption is satisfied.
    (A) A depth edge is invalid if both of its "far_flag"(BG) and
    "near_flag"(FG) are True.
    (B) If the depth edge is invalid, we need to do:
    (1) Find the role("oridinal") of each pixel on the depth edge.
    "-1" --> Its opposite pixels has smaller depth(near) than it.
    It is a backgorund pixel.
    "+1" --> Its opposite pixels has larger depth(far) than it.
    It is a foregorund pixel.
    "0" --> Some of opposite pixels has larger depth(far) than it,
    and some has smaller pixel than it.
    It is an ambiguous pixel.
    (2) For each pixel "discont_node", check if its neigbhors' roles are consistent.
    - If not, break the connection between the neighbor "ne_node" that has a role
    different from "discont_node".
    - If yes, remove all the role that are inconsistent to its neighbors "ne_node".
    (3) Connected component analysis to re-identified those divided depth edge.
    (C) Aggregate each connected part in "discont_graph" into "discont_ccs" (A.K.A. depth edge).
    '''
    if remove_conflict_ordinal:
        new_discont_ccs = []
        num_new_cc = 0
        for edge_id, discont_cc in enumerate(discont_ccs):
            near_flag = False
            far_flag = False
            # (A) A node with a 'far' list sets near_flag; one with a 'near' list sets far_flag.
            for discont_node in discont_cc:
                near_flag = True if key_exist(LDI.nodes[discont_node], 'far') else near_flag
                far_flag = True if key_exist(LDI.nodes[discont_node], 'near') else far_flag
                if far_flag and near_flag:
                    break
            if far_flag and near_flag:
                # (B1) ordinal = -1*[has 'far'] + 1*[has 'near'], giving -1, 0 or +1.
                for discont_node in discont_cc:
                    discont_graph.nodes[discont_node]['ordinal'] = \
                        np.array([key_exist(LDI.nodes[discont_node], 'far'),
                                  key_exist(LDI.nodes[discont_node], 'near')]) * \
                        np.array([-1, 1])
                    discont_graph.nodes[discont_node]['ordinal'] = \
                        np.sum(discont_graph.nodes[discont_node]['ordinal'])
                remove_nodes, remove_edges = [], []
                for discont_node in discont_cc:
                    # (B2) Neighbours agree with each other only if |sum of ordinals| equals their count.
                    ordinal_relation = np.sum([discont_graph.nodes[xx]['ordinal'] \
                                               for xx in discont_graph.neighbors(discont_node)])
                    near_side = discont_graph.nodes[discont_node]['ordinal'] <= 0
                    if abs(ordinal_relation) < len([*discont_graph.neighbors(discont_node)]):
                        remove_nodes.append(discont_node)
                        for ne_node in discont_graph.neighbors(discont_node):
                            remove_flag = (near_side and not(key_exist(LDI.nodes[ne_node], 'far'))) or \
                                          (not near_side and not(key_exist(LDI.nodes[ne_node], 'near')))
                            remove_edges += [(discont_node, ne_node)] if remove_flag else []
                    else:
                        # Consistent neighbourhood: drop the role that disagrees with it.
                        if near_side and key_exist(LDI.nodes[discont_node], 'near'):
                            LDI.nodes[discont_node].pop('near')
                        elif not(near_side) and key_exist(LDI.nodes[discont_node], 'far'):
                            LDI.nodes[discont_node].pop('far')
                discont_graph.remove_edges_from(remove_edges)
                # (B3) Re-split this depth edge into connected sub-parts.
                sub_mesh = discont_graph.subgraph(list(discont_cc)).copy()
                sub_discont_ccs = [*netx.connected_components(sub_mesh)]
                # A singleton sub-part made of a flagged node that still has a 'far' list.
                is_redun_near = lambda xx: len(xx) == 1 and xx[0] in remove_nodes and key_exist(LDI.nodes[xx[0]], 'far')
                for sub_discont_cc in sub_discont_ccs:
                    if is_redun_near(list(sub_discont_cc)):
                        LDI.nodes[list(sub_discont_cc)[0]].pop('far')
                    new_discont_ccs.append(sub_discont_cc)
            else:
                new_discont_ccs.append(discont_cc)
        discont_ccs = new_discont_ccs
        new_discont_ccs = None
    # Debug hook.
    if spdb == True:
        import pdb; pdb.set_trace()
    # Tag every node with the index of its depth edge.
    for edge_id, edge_cc in enumerate(discont_ccs):
        for node in edge_cc:
            LDI.nodes[node]['edge_id'] = edge_id
    return discont_ccs, LDI, discont_graph
def combine_end_node(mesh, edge_mesh, edge_ccs, depth):
    """Connect touching end points of depth edges in *mesh*.

    An end point is a depth-edge node with exactly one neighbour in
    *edge_mesh*. ``end_maps`` stores each end point's z value (0 = none).
    The end points are filtered in three passes:

    1. Drop end points whose node is not in *mesh*. Also drop end points
       already connected in *mesh* to every 4-adjacent end point; this
       includes end points with no 4-adjacent end point at all.
    2. Drop an end point when a 4-adjacent end point's edge id is counted
       more than once in ``connect_dict`` for the end point's own edge.
       ``connect_dict[e]`` counts, over the nodes of edge ``e``, the edge ids
       found in the 'far'/'near' lists of each node's mesh neighbours.
    3. Connect each remaining end point to its 4-adjacent remaining end
       points in *mesh*. Record the partner, plus the partner's *edge_mesh*
       neighbours within one pixel, in each node's ``'must_connect'`` set.

    Args:
        mesh: LDI graph; ``mesh.graph`` provides 'H' and 'W'.
        edge_mesh: graph linking depth-edge pixels (neighbour counts find end
            points).
        edge_ccs: iterable of node sets, one per depth edge (its index is the
            edge id).
        depth: unused; kept for interface compatibility.

    Returns:
        The same *mesh*, modified in place.
    """
    import collections

    mesh_nodes = mesh.nodes
    H, W = mesh.graph['H'], mesh.graph['W']

    def _adjacent_ends(end_maps, nx, ny):
        # In-bounds 4-neighbours of (nx, ny) that are (still) end points.
        return [xx for xx in [(nx - 1, ny), (nx + 1, ny), (nx, ny - 1), (nx, ny + 1)]
                if 0 <= xx[0] < H and 0 <= xx[1] < W and end_maps[xx[0], xx[1]] != 0]

    def _add_must_connect(owner, partner):
        # owner must connect to partner and to partner's edge_mesh neighbours
        # within one pixel of owner. abs() matches the window test in
        # group_edges; the old one-sided test also admitted offsets of -2.
        attrs = mesh_nodes[owner]
        if attrs.get('must_connect') is None:
            attrs['must_connect'] = set()
        attrs['must_connect'].add(partner)
        attrs['must_connect'] |= set(xx for xx in edge_mesh.neighbors(partner)
                                     if abs(xx[0] - owner[0]) < 2 and abs(xx[1] - owner[1]) < 2)

    connect_dict = dict()
    for valid_edge_id, valid_edge_cc in enumerate(edge_ccs):
        connect_info = []
        for valid_edge_node in valid_edge_cc:
            single_connect = set()
            for ne_node in mesh.neighbors(valid_edge_node):
                for opp_key in ('far', 'near'):
                    opposite = mesh_nodes[ne_node].get(opp_key)
                    if opposite is None:
                        continue
                    for fn in opposite:
                        if mesh.has_node(fn) and mesh_nodes[fn].get('edge_id') is not None:
                            single_connect.add(mesh_nodes[fn]['edge_id'])
            connect_info.extend(single_connect)
        connect_dict[valid_edge_id] = collections.Counter(connect_info)

    end_maps = np.zeros((H, W))
    for valid_edge_cc in edge_ccs:
        for valid_edge_node in valid_edge_cc:
            if len([*edge_mesh.neighbors(valid_edge_node)]) == 1:
                end_maps[valid_edge_node[0], valid_edge_node[1]] = valid_edge_node[2]

    # Pass 1: drop missing nodes and end points already linked to all adjacent ends.
    invalid_nodes = set()
    for nx, ny in zip(*np.where(end_maps != 0)):
        end_node = (nx, ny, end_maps[nx, ny])
        if not mesh.has_node(end_node):
            invalid_nodes.add((nx, ny))
            continue
        four_nes = _adjacent_ends(end_maps, nx, ny)
        mesh_nes = [*mesh.neighbors(end_node)]
        linked = sum(1 for fne in four_nes if (fne[0], fne[1], end_maps[fne[0], fne[1]]) in mesh_nes)
        if linked == len(four_nes):
            invalid_nodes.add((nx, ny))
    for x, y in invalid_nodes:
        end_maps[x, y] = 0

    # Pass 2: drop end points next to an edge counted more than once in connect_dict.
    invalid_nodes = set()
    for nx, ny in zip(*np.where(end_maps != 0)):
        self_id = mesh_nodes[(nx, ny, end_maps[nx, ny])].get('edge_id')
        if self_id is None:
            continue
        self_connect = connect_dict.get(self_id, {})
        for fne in _adjacent_ends(end_maps, nx, ny):
            ne_id = mesh_nodes[(fne[0], fne[1], end_maps[fne[0], fne[1]])].get('edge_id')
            if ne_id is None:
                continue
            count = self_connect.get(ne_id)
            if count is not None and count != 1:
                invalid_nodes.add((nx, ny))
    for x, y in invalid_nodes:
        end_maps[x, y] = 0

    # Pass 3: connect remaining adjacent end points and record must_connect.
    for nx, ny in zip(*np.where(end_maps != 0)):
        node_b = (nx, ny, end_maps[nx, ny])
        for fne in _adjacent_ends(end_maps, nx, ny):
            node_a = (fne[0], fne[1], end_maps[fne[0], fne[1]])
            if not mesh.has_node(node_a):
                continue
            mesh.add_edge(node_a, node_b)
            _add_must_connect(node_b, node_a)
            _add_must_connect(node_a, node_b)
    return mesh
def remove_redundant_edge(mesh, edge_mesh, edge_ccs, info_on_pix, config, redundant_number=1000, invalid=False, spdb=False):
    """Re-attach short depth edges to the mesh when they end in open space.

    The depth edges are first filtered with ``filter_edge``. Each filtered
    edge with at most ``redundant_number`` nodes is scanned. The scan stops
    at the first node with 3 ``edge_mesh`` neighbours; counts gathered so far
    still apply.

    * Nodes with 1 ``edge_mesh`` neighbour are end points:
      - ``invalid`` is False: the edge qualifies if at least one end point
        has no pixel of a *different* filtered edge among its 8 neighbours.
      - ``invalid`` is True: the edge qualifies if at least two end points
        have no such pixel among their 4 neighbours, at least one of them
        has none among its 8 neighbours, and at least one has none in its
        5x5 window.
    * Nodes with 0 ``edge_mesh`` neighbours are re-linked to their
      unconnected 4-neighbours.

    Every node of a qualifying edge is re-linked to its unconnected
    4-neighbours that are off-edge or on the same filtered edge. Links are
    only made where ``_may_reconnect`` allows them.

    Args:
        mesh: LDI graph; ``mesh.graph`` provides 'H' and 'W'.
        edge_mesh: graph linking depth-edge pixels.
        edge_ccs: list of node sets, one per depth edge.
        info_on_pix: dict ``(x, y) -> [info dict]``; ``[0]['depth']`` is
            used to build neighbour node keys.
        config: passed through to ``filter_edge``.
        redundant_number: edges larger than this are left alone.
        invalid: selects the stricter end-point test described above.
        spdb: if True, show the edge canvas and stop in pdb (debugging aid).

    Returns:
        The same *mesh*, modified in place.
    """
    H, W = mesh.graph['H'], mesh.graph['W']
    point_to_amount = {}
    point_to_id = {}
    end_maps = np.zeros((H, W)) - 1
    for valid_edge_id, valid_edge_cc in enumerate(edge_ccs):
        for valid_edge_node in valid_edge_cc:
            point_to_amount[valid_edge_node] = len(valid_edge_cc)
            point_to_id[valid_edge_node] = valid_edge_id
            if edge_mesh.has_node(valid_edge_node) and len([*edge_mesh.neighbors(valid_edge_node)]) == 1:
                end_maps[valid_edge_node[0], valid_edge_node[1]] = valid_edge_id

    # Edge id -> ids of edges whose end points are 4-adjacent to one of its end points.
    point_to_adjoint = {}
    nxs, nys = np.where(end_maps > -1)
    for nx, ny in zip(nxs, nys):
        # Bounds-checked: border end points no longer raise IndexError or wrap via -1.
        adjoint_edges = set(end_maps[x, y] for x, y in [(nx + 1, ny), (nx - 1, ny), (nx, ny + 1), (nx, ny - 1)]
                            if 0 <= x < H and 0 <= y < W and end_maps[x, y] != -1)
        self_id = end_maps[nx, ny]
        point_to_adjoint[self_id] = point_to_adjoint.get(self_id, set()) | adjoint_edges

    def _may_reconnect(edge_node, ne):
        # Allowed in "invalid" mode, towards pixels not on a large depth edge,
        # or when the two edges touch at end points. The .get lookups replace a
        # bare ``except: pdb.set_trace()`` that stalled non-interactive runs.
        if invalid is True:
            return True
        amount = point_to_amount.get(ne)
        if amount is None or amount < redundant_number:
            return True
        return point_to_id.get(ne) in point_to_adjoint.get(point_to_id.get(edge_node), set())

    valid_edge_ccs = filter_edge(mesh, edge_ccs, config, invalid=invalid)
    edge_canvas = np.zeros((H, W)) - 1
    for valid_edge_id, valid_edge_cc in enumerate(valid_edge_ccs):
        for valid_edge_node in valid_edge_cc:
            edge_canvas[valid_edge_node[0], valid_edge_node[1]] = valid_edge_id
    if spdb is True:
        plt.imshow(edge_canvas); plt.show()
        import pdb; pdb.set_trace()

    def _other_edge_pixels(coords, edge_id):
        # Pixels in coords that have node info and lie on a different filtered edge.
        return [(x, y) for x, y in coords
                if info_on_pix.get((x, y)) is not None and edge_canvas[x, y] != -1 and edge_canvas[x, y] != edge_id]

    def _unlinked_four_nes(edge_node):
        # 4-neighbour nodes (first layer) not yet connected to edge_node in mesh.
        hx, hy = edge_node[0], edge_node[1]
        return [(x, y, info_on_pix[(x, y)][0]['depth']) for x, y in [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1)]
                if info_on_pix.get((x, y)) is not None and
                not mesh.has_edge(edge_node, (x, y, info_on_pix[(x, y)][0]['depth']))]

    for valid_edge_id, valid_edge_cc in enumerate(valid_edge_ccs):
        if len(valid_edge_cc) > redundant_number:
            continue
        end_number = 0
        eight_end_number = 0
        db_eight_end_number = 0
        for valid_edge_node in valid_edge_cc:
            num_edge_nes = len([*edge_mesh.neighbors(valid_edge_node)])
            if num_edge_nes == 3:
                break
            elif num_edge_nes == 1:
                hx, hy = valid_edge_node[0], valid_edge_node[1]
                four_coords = [(hx + 1, hy), (hx - 1, hy), (hx, hy + 1), (hx, hy - 1)]
                eight_coords = four_coords + [(hx + 1, hy + 1), (hx - 1, hy - 1), (hx - 1, hy + 1), (hx + 1, hy - 1)]
                if invalid is False:
                    if not _other_edge_pixels(eight_coords, valid_edge_id):
                        end_number += 1
                if invalid is True:
                    db_coords = [(x, y) for x in range(hx - 2, hx + 3) for y in range(hy - 2, hy + 3) if (x, y) != (hx, hy)]
                    four_empty = not _other_edge_pixels(four_coords, valid_edge_id)
                    eight_empty = not _other_edge_pixels(eight_coords, valid_edge_id)
                    db_eight_empty = not _other_edge_pixels(db_coords, valid_edge_id)
                    if four_empty or eight_empty:
                        end_number += 1
                        if eight_empty:
                            eight_end_number += 1
                        if db_eight_empty:
                            db_eight_end_number += 1
            elif num_edge_nes == 0:
                for ne in _unlinked_four_nes(valid_edge_node):
                    if _may_reconnect(valid_edge_node, ne):
                        mesh.add_edge(valid_edge_node, ne)
        if (invalid is not True and end_number >= 1) or \
                (invalid is True and end_number >= 2 and eight_end_number >= 1 and db_eight_end_number >= 1):
            for valid_edge_node in valid_edge_cc:
                four_nes = [ne for ne in _unlinked_four_nes(valid_edge_node)
                            if edge_canvas[ne[0], ne[1]] == -1 or edge_canvas[ne[0], ne[1]] == valid_edge_id]
                for ne in four_nes:
                    if _may_reconnect(valid_edge_node, ne):
                        mesh.add_edge(valid_edge_node, ne)
    return mesh
def judge_dangle(mark, mesh, node):
    """Flag *node* in *mark* when it has few interior mesh neighbours.

    Nodes on the outermost row/column of ``mesh.graph['H'] x ['W']`` are
    ignored. Only neighbours strictly inside that frame are counted.
    - 0 or 1 such neighbour: ``mark[x, y] = count + 1``.
    - 2 neighbours that are more than one pixel apart on some axis:
      ``mark[x, y] = 3``.
    - 3 or more neighbours: the node is not marked.
    Returns *mark*, which is modified in place.
    """
    height, width = mesh.graph['H'], mesh.graph['W']
    row, col = node[0], node[1]
    if not (1 <= row < height - 1 and 1 <= col < width - 1):
        return mark
    interior = [ne for ne in mesh.neighbors(node)
                if 0 < ne[0] < height - 1 and 0 < ne[1] < width - 1]
    count = len(interior)
    if count >= 3:
        return mark
    if count <= 1:
        mark[row, col] = count + 1
        return mark
    first, second = interior[0], interior[1]
    if abs(first[0] - second[0]) > 1 or abs(first[1] - second[1]) > 1:
        mark[row, col] = 3
    return mark
def remove_dangling(mesh, edge_ccs, edge_mesh, info_on_pix, image, depth, config):
    """Reconnect or re-depth dangling vertices along the depth edges of ``mesh``.

    Nodes are ``(x, y, depth)`` tuples. Every neighbourhood lookup below uses
    the first layer stored at a pixel, i.e. ``info_on_pix[(x, y)][0]['depth']``.

    The repair runs in four passes:

    1. Each single-node edge component is handled in one of two ways. If the
       4-adjacent part of some connected group of its 8-neighbours holds at
       least two nodes, the node is moved to that group's mean depth and
       attached to it (``edge_ccs`` and ``info_on_pix`` are updated).
       Otherwise it is connected to all of its 4-neighbours.
    2. Edge nodes inside the ``bord_*`` limits are scored into ``mark``:
       +1 for no mesh neighbour, +2 for one, +3 for two neighbours that are
       not adjacent to each other.
    3. Nodes scored 1 get edges to their 4-neighbours and take their mean
       depth. Nodes scored 2 are re-attached to the largest connected group
       of 8-neighbours not already reachable through their current neighbour
       (ties: smallest total distance) and take that group's mean depth.
    4. Connected runs of degree-2 nodes scored 3 are re-attached at one
       endpoint, and ``recursive_add_edge`` continues from there. Isolated
       ones are connected to their 4-adjacent, connected neighbours.

    Whenever a new depth cannot be computed (no contributing neighbour), the
    node keeps its current depth.

    Args:
        mesh: graph of ``(x, y, depth)`` nodes. ``mesh.graph`` must hold
            ``H``, ``W`` and ``bord_up``/``bord_down``/``bord_left``/``bord_right``.
        edge_ccs: list of sets of edge nodes; modified in place by pass 1.
        edge_mesh: edge graph, threaded through ``update_info`` and
            ``recursive_add_edge``.
        info_on_pix: per-pixel layer info; modified in place.
        image: not used by this function.
        depth: depth map; ``abs`` of every new node depth is written into it.
        config: not used by this function.

    Returns:
        ``(mesh, info_on_pix, edge_mesh, depth, mark)``. ``mark`` is the float
        map of dangling scores, reset to 0 at repaired pixels.
    """
    four_offsets = ((1, 0), (-1, 0), (0, 1), (0, -1))
    eight_offsets = four_offsets + ((1, 1), (-1, -1), (-1, 1), (1, -1))

    def pix_nodes(hx, hy, offsets):
        # First-layer node ids of the pixels (hx, hy) + offsets known to info_on_pix.
        nodes = []
        for dx, dy in offsets:
            x, y = hx + dx, hy + dy
            if info_on_pix.get((x, y)) is not None:
                nodes.append((x, y, info_on_pix[(x, y)][0]['depth']))
        return nodes

    def marked_nodes(value):
        # First-layer nodes, still present in mesh, at pixels where mark == value.
        xs, ys = np.where(mark == value)
        nodes = [(x, y, info_on_pix[(x, y)][0]['depth']) for x, y in zip(xs, ys)]
        return [n for n in nodes if mesh.has_node(n)]

    # Pass 1: single-node edge components. Iterate over a shallow copy because
    # list entries are replaced below; the sets themselves are not mutated.
    for edge_cc_id, valid_edge_cc in enumerate(list(edge_ccs)):
        if len(valid_edge_cc) != 1:
            continue
        single_edge_node = next(iter(valid_edge_cc))
        hx, hy, hz = single_edge_node
        eight_nes = set(pix_nodes(hx, hy, eight_offsets))
        four_nes = pix_nodes(hx, hy, four_offsets)
        sub_mesh = mesh.subgraph(eight_nes).copy()
        # For every connected group of 8-neighbours keep only its 4-adjacent members.
        four_ccs = [{n for n in cc if abs(n[0] - hx) + abs(n[1] - hy) < 2}
                    for cc in netx.connected_components(sub_mesh)]
        if four_ccs:
            largest_cc = sorted(
                four_ccs,
                key=lambda cc: (len(cc), -np.sum([abs(n[2] - hz) for n in cc])))[-1]
        else:
            largest_cc = set()
        if len(largest_cc) < 2:
            for ne in four_nes:
                mesh.add_edge(single_edge_node, ne)
        else:
            mesh.remove_edges_from([(single_edge_node, ne) for ne in mesh.neighbors(single_edge_node)])
            new_depth = np.mean([n[2] for n in largest_cc])
            pix_info = info_on_pix[(hx, hy)][0]
            pix_info['depth'] = new_depth
            pix_info['disp'] = 1. / new_depth
            new_node = (hx, hy, new_depth)
            # `Graph.nodes[n]` (networkx >= 2.0); `Graph.node` was removed in networkx 2.4.
            mesh = refresh_node(single_edge_node, mesh.nodes[single_edge_node], new_node, dict(), mesh)
            edge_ccs[edge_cc_id] = {new_node}
            for ne in largest_cc:
                mesh.add_edge(new_node, ne)

    # Pass 2: score dangling edge nodes.
    mark = np.zeros((mesh.graph['H'], mesh.graph['W']))
    bord_up, bord_down = mesh.graph['bord_up'], mesh.graph['bord_down']
    bord_left, bord_right = mesh.graph['bord_left'], mesh.graph['bord_right']
    for edge_cc in edge_ccs:
        for edge_node in edge_cc:
            if not (bord_up <= edge_node[0] < bord_down - 1) or \
                    not (bord_left <= edge_node[1] < bord_right - 1):
                continue
            # NOTE(review): every mesh neighbour is counted here, including ones
            # outside the bord_* limits (judge_dangle filters those out) - confirm.
            neighbors = [*mesh.neighbors(edge_node)]
            if len(neighbors) >= 3:
                continue
            if len(neighbors) <= 1:
                mark[edge_node[0], edge_node[1]] += len(neighbors) + 1
            else:
                ne_a, ne_b = neighbors
                if abs(ne_a[0] - ne_b[0]) > 1 or abs(ne_a[1] - ne_b[1]) > 1:
                    mark[edge_node[0], edge_node[1]] += 3

    # Pass 3: isolated (score 1) and single-neighbour (score 2) nodes.
    conn_0_nodes = marked_nodes(1)
    conn_1_nodes = marked_nodes(2)
    for node in conn_0_nodes:
        four_nes = pix_nodes(node[0], node[1], four_offsets)
        for ne in four_nes:
            mesh.add_edge(node, ne)
        # Mean depth of the new neighbours; keep the current depth if there are none.
        re_depth = sum(ne[2] for ne in four_nes) / len(four_nes) if four_nes else node[2]
        info_on_pix, mesh, edge_mesh = update_info({node: (node[0], node[1], re_depth)}, info_on_pix, mesh, edge_mesh)
        depth[node[0], node[1]] = abs(re_depth)
        mark[node[0], node[1]] = 0
    for node in conn_1_nodes:
        hx, hy = node[0], node[1]
        eight_nes = set(pix_nodes(hx, hy, eight_offsets))
        # Ignore pixels already reachable through the node's current neighbour(s).
        self_nes = {ne2 for ne1 in mesh.neighbors(node) for ne2 in mesh.neighbors(ne1) if ne2 in eight_nes}
        eight_nes = [*(eight_nes - self_nes)]
        sub_mesh = mesh.subgraph(eight_nes).copy()
        ccs = list(netx.connected_components(sub_mesh))
        if ccs:
            largest_cc = sorted(
                ccs,
                key=lambda cc: (len(cc), -np.sum([abs(n[0] - hx) + abs(n[1] - hy) for n in cc])))[-1]
        else:
            largest_cc = set()
        mesh.remove_edges_from([(ne, node) for ne in mesh.neighbors(node)])
        value, count = 0, 0
        for cc_node in largest_cc:
            if cc_node[0] == hx and cc_node[1] == hy:
                continue
            value += cc_node[2]
            count += 1
            if abs(cc_node[0] - hx) + abs(cc_node[1] - hy) < 2:
                mesh.add_edge(cc_node, node)
        re_depth = value / count if count > 0 else node[2]
        renode = (hx, hy, re_depth)
        info_on_pix, mesh, edge_mesh = update_info({node: renode}, info_on_pix, mesh, edge_mesh)
        depth[hx, hy] = abs(re_depth)
        mark[hx, hy] = 0
        edge_mesh, mesh, mark, info_on_pix = recursive_add_edge(edge_mesh, mesh, info_on_pix, renode, mark)

    # Pass 4: runs of degree-2 "bridge" nodes (score 3).
    conn_2_nodes = [n for n in marked_nodes(3) if mesh.degree(n) == 2]
    sub_mesh = mesh.subgraph(conn_2_nodes).copy()
    for cc in netx.connected_components(sub_mesh):
        # Endpoints of the run have exactly one neighbour inside it.
        candidate_nodes = [n for n in cc if sub_mesh.degree(n) == 1]
        for node in candidate_nodes:
            if not mesh.has_node(node):
                continue
            ne_node = next((n for n in mesh.neighbors(node) if n not in cc), None)
            if ne_node is None:
                continue
            hx, hy = node[0], node[1]
            eight_nes = {n for n in pix_nodes(hx, hy, eight_offsets) if n not in cc}
            ne_sub_mesh = mesh.subgraph(eight_nes).copy()
            ne_cc = next((c for c in netx.connected_components(ne_sub_mesh) if ne_node in c), None)
            if ne_cc is None:
                # ne_node is not among the first-layer 8-neighbours (e.g. it lives
                # on another layer); leave this endpoint untouched.
                continue
            largest_cc = [n for n in ne_cc if abs(n[0] - hx) + abs(n[1] - hy) == 1]
            mesh.remove_edges_from([(n, node) for n in mesh.neighbors(node)])
            value, count = 0, 0
            for cc_node in largest_cc:
                value += cc_node[2]
                count += 1
                mesh.add_edge(cc_node, node)
            re_depth = value / count if count > 0 else node[2]
            renode = (hx, hy, re_depth)
            info_on_pix, mesh, edge_mesh = update_info({node: renode}, info_on_pix, mesh, edge_mesh)
            depth[hx, hy] = abs(re_depth)
            mark[hx, hy] = 0
            edge_mesh, mesh, mark, info_on_pix = recursive_add_edge(edge_mesh, mesh, info_on_pix, renode, mark)
            break
        if len(cc) == 1:
            node = next(iter(cc))
            hx, hy = node[0], node[1]
            nine_nes = {n for n in pix_nodes(hx, hy, ((0, 0),) + eight_offsets) if mesh.has_node(n)}
            ne_sub_mesh = mesh.subgraph(nine_nes).copy()
            for ne_cc in netx.connected_components(ne_sub_mesh):
                if node not in ne_cc:
                    continue
                value, count = 0, 0
                for ne in ne_cc:
                    if abs(ne[0] - hx) + abs(ne[1] - hy) == 1:
                        mesh.add_edge(node, ne)
                        value += ne[2]
                        count += 1
                re_depth = value / count if count > 0 else node[2]
                info_on_pix, mesh, edge_mesh = update_info({node: (hx, hy, re_depth)}, info_on_pix, mesh, edge_mesh)
                depth[hx, hy] = abs(re_depth)
                mark[hx, hy] = 0
                break  # connected components are disjoint: only one can hold node
    return mesh, info_on_pix, edge_mesh, depth, mark
def context_and_holes(mesh, edge_ccs, config, specific_edge_id, specific_edge_loc, depth_feat_model,
                      connect_points_ccs=None, inpaint_iter=0, filter_edge=False, vis_edge_id=None):
    """Grow the inpainting mask and the context region around every depth edge.

    For each edge component in ``edge_ccs`` the function works as follows:

    * The context is seeded with the ``'far'`` nodes of its edge nodes. For
      edges longer than two nodes, a far node is dropped when it is the only
      one contributed by its ``edge_id``.
    * The mask (seeded with the edge nodes) and the context are grown ring by
      ring over mesh adjacency for ``max(background_thickness,
      context_thickness)`` rings. Only the first ``background_thickness``
      rings are recorded in ``mask_ccs``.
    * At ``inpaint_iter == 0`` the grown maps are fed to ``depth_inpainting``.
      A far node is kept as extended context only if a 4-neighbour's
      inpainted near depth is non-zero and smaller in magnitude than the far
      node's depth.
    * The innermost context rings move to ``erode_context_ccs``. At most
      ``config['depth_edge_dilate']`` rings are moved at iteration 0 (0
      otherwise), reduced until they hold at most 30% of the context nodes.

    At ``inpaint_iter == 0`` the extended contexts are finally grown for
    ``background_thickness`` rings, constrained to the context of the edges
    their accompanying near nodes lie on.

    Args:
        mesh: graph of ``(x, y, depth)`` nodes with ``graph['H']``/``graph['W']``.
            Node attributes read: ``far``, ``near``, ``edge_id``, ``inpaint_id``,
            ``connect_point_id``, ``connect_point_exception``.
        edge_ccs: list of sets of edge nodes.
        config: reads ``background_thickness``/``context_thickness`` (iteration 0)
            or their ``_2`` variants, plus ``depth_edge_dilate`` and
            ``log_depth`` at iteration 0.
        specific_edge_id: edge ids to process; an empty collection means all.
        specific_edge_loc, filter_edge: not used by this function.
        depth_feat_model: passed to ``depth_inpainting`` at iteration 0.
        connect_points_ccs: optional list of node sets indexed by
            ``connect_point_id``.
        inpaint_iter: 0 for the first pass; later passes only grow into nodes
            whose ``inpaint_id`` is 1.
        vis_edge_id: debugging aid; plots and breaks into pdb for that edge.

    Returns:
        ``(context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs,
        invalid_extend_edge_ccs, edge_maps, extend_context_ccs, extend_edge_ccs,
        extend_erode_context_ccs)``.
    """
    import warnings

    H, W = mesh.graph['H'], mesh.graph['W']
    num_edges = len(edge_ccs)
    mesh_nodes = mesh.nodes

    # Pixel -> id of the edge component covering it (-1 where there is none).
    edge_maps = np.zeros((H, W)) - 1
    for edge_id, edge_cc in enumerate(edge_ccs):
        for edge_node in edge_cc:
            edge_maps[edge_node[0], edge_node[1]] = edge_id

    context_ccs = [set() for _ in range(num_edges)]
    extend_context_ccs = [set() for _ in range(num_edges)]
    extend_erode_context_ccs = [set() for _ in range(num_edges)]
    extend_edge_ccs = [set() for _ in range(num_edges)]
    accomp_extend_context_ccs = [set() for _ in range(num_edges)]
    erode_context_ccs = [set() for _ in range(num_edges)]
    broken_mask_ccs = [set() for _ in range(num_edges)]  # never filled here; part of the return value
    invalid_extend_edge_ccs = [set() for _ in range(num_edges)]
    intouched_ccs = [set() for _ in range(num_edges)]
    if inpaint_iter == 0:
        background_thickness = config['background_thickness']
        context_thickness = config['context_thickness']
    else:
        background_thickness = config['background_thickness_2']
        context_thickness = config['context_thickness_2']

    # Seed each edge's context with the far nodes of its edge nodes.
    for edge_id, edge_cc in enumerate(edge_ccs):
        if context_thickness == 0 or (len(specific_edge_id) > 0 and edge_id not in specific_edge_id):
            continue
        edge_group = {}
        for edge_node in edge_cc:
            far_nodes = mesh_nodes[edge_node].get('far')
            if far_nodes is None:
                continue
            for far_node in far_nodes:
                if far_node in edge_cc:
                    continue
                context_ccs[edge_id].add(far_node)
                far_edge_id = mesh_nodes[far_node].get('edge_id')
                if far_edge_id is not None:
                    edge_group.setdefault(far_edge_id, set()).add(far_node)
        if len(edge_cc) > 2:
            # Drop a far node that is the only one contributed by its edge_id.
            for group in edge_group.values():
                if len(group) == 1:
                    context_ccs[edge_id].remove(next(iter(group)))

    # Near nodes that belong to another edge with a context must stay untouched.
    if inpaint_iter == 0:
        for edge_id, edge_cc in enumerate(edge_ccs):
            tmp_intouched_nodes = set()
            for edge_node in edge_cc:
                near_nodes = mesh_nodes[edge_node].get('near')
                raw_intouched_nodes = set(near_nodes) if near_nodes is not None else set()
                tmp_intouched_nodes |= {xx for xx in raw_intouched_nodes
                                        if mesh_nodes[xx].get('edge_id') is not None and
                                        len(context_ccs[mesh_nodes[xx].get('edge_id')]) > 0}
            intouched_ccs[edge_id] |= tmp_intouched_nodes

    mask_ccs = copy.deepcopy(edge_ccs)
    forbidden_len = 3
    forbidden_map = np.ones((H - forbidden_len, W - forbidden_len))
    # NOTE(review): the padded map is (H + 3) x (W + 3), so only the first three
    # rows/columns are forbidden, not the last three; confirm whether
    # H - 2 * forbidden_len was intended.
    forbidden_map = np.pad(forbidden_map, ((forbidden_len, forbidden_len), (forbidden_len, forbidden_len)),
                           mode='constant').astype(bool)
    # Ring index at which mask growth (iteration 0) / context growth (iteration 1)
    # starts diverting pixels into the separately grown "intersect" fronts.
    passive_background = 10
    passive_context = 1

    def is_free(px, py):
        # Pixel is allowed and not yet claimed by any of the per-edge maps.
        return (not context_map[px, py] and not mask_map[px, py] and forbidden_map[px, py]
                and not intouched_map[px, py] and not intersect_map[px, py]
                and not intersect_context_map[px, py])

    def sorted_layer_neighbors(node, four_nes):
        # node's mesh neighbours grouped by the pixel keys of four_nes, each
        # pixel's layers ordered by decreasing |depth|.
        for ne in mesh.neighbors(node):
            four_nes[(ne[0], ne[1])].append(ne[2])
        ordered = []
        for (px, py), depths in four_nes.items():
            depths.sort(key=abs, reverse=True)
            ordered.extend((px, py, d) for d in depths)
        return ordered

    def grow_intouched(seeds):
        # One ring of growth for the intouched/redundant fronts.
        grown = []
        for node in seeds:
            if context_map[node[0], node[1]] or mask_map[node[0], node[1]]:
                continue
            for ne in mesh.neighbors(node):
                if (not context_map[ne[0], ne[1]] and not mask_map[ne[0], ne[1]]
                        and not intouched_map[ne[0], ne[1]] and forbidden_map[ne[0], ne[1]]):
                    grown.append(ne)
                    intouched_map[ne[0], ne[1]] = True
        return set(grown)

    for edge_id, edge_cc in enumerate(edge_ccs):
        cur_mask_cc = []
        cur_context_cc = []
        cur_accomp_near_cc = []
        cur_invalid_extend_edge_cc = []
        cur_comp_far_cc = []
        tmp_erode = []
        if len(context_ccs[edge_id]) == 0 or (len(specific_edge_id) > 0 and edge_id not in specific_edge_id):
            continue
        for i in range(max(background_thickness, context_thickness)):
            if i == 0:
                tmp_mask_nodes = copy.deepcopy(mask_ccs[edge_id])
                tmp_intersect_nodes = []
                tmp_intersect_context_nodes = []
                mask_map = np.zeros((H, W), dtype=bool)
                context_depth = np.zeros((H, W))
                comp_cnt_depth = np.zeros((H, W))
                connect_map = np.zeros((H, W))
                for node in tmp_mask_nodes:
                    mask_map[node[0], node[1]] = True
                    depth_count = 0
                    if mesh_nodes[node].get('far') is not None:
                        for comp_cnt_node in mesh_nodes[node]['far']:
                            comp_cnt_depth[node[0], node[1]] += abs(comp_cnt_node[2])
                            depth_count += 1
                    if depth_count > 0:
                        comp_cnt_depth[node[0], node[1]] = comp_cnt_depth[node[0], node[1]] / depth_count
                    connect_point_id = mesh_nodes[node].get('connect_point_id')
                    if connect_point_id is not None and connect_point_id > -1 and connect_points_ccs is not None:
                        for xx in connect_points_ccs[connect_point_id]:
                            if connect_map[xx[0], xx[1]] == 0:
                                connect_map[xx[0], xx[1]] = xx[2]
                    if mesh_nodes[node].get('connect_point_exception') is not None:
                        for xx in mesh_nodes[node]['connect_point_exception']:
                            if connect_map[xx[0], xx[1]] == 0:
                                connect_map[xx[0], xx[1]] = xx[2]
                tmp_context_nodes = [*context_ccs[edge_id]]
                tmp_erode.append([*context_ccs[edge_id]])
                context_map = np.zeros((H, W), dtype=bool)
                for node in tmp_context_nodes:
                    context_map[node[0], node[1]] = True
                    context_depth[node[0], node[1]] = node[2]
                context_map[mask_map] = False
                tmp_intouched_nodes = [*intouched_ccs[edge_id]]
                intouched_map = np.zeros((H, W), dtype=bool)
                for node in tmp_intouched_nodes:
                    intouched_map[node[0], node[1]] = True
                intouched_map[mask_map] = False
                tmp_redundant_nodes = set()
                intersect_map = np.zeros((H, W), dtype=bool)
                intersect_context_map = np.zeros((H, W), dtype=bool)

            # Grow the background "intersect" front (iteration 0 only).
            if i > passive_background and inpaint_iter == 0:
                thin_step = (i - passive_background) % 2 == 0 and (i - passive_background) % 8 != 0
                new_tmp_intersect_nodes = []
                for node in tmp_intersect_nodes:
                    for ne in mesh.neighbors(node):
                        if not is_free(ne[0], ne[1]):
                            continue
                        if thin_step and any(
                                mask_map[fx, fy]
                                for fx, fy in ((ne[0] - 1, ne[1]), (ne[0] + 1, ne[1]), (ne[0], ne[1] - 1), (ne[0], ne[1] + 1))
                                if 0 <= fx < H and 0 <= fy < W):
                            continue
                        intersect_map[ne[0], ne[1]] = True
                        new_tmp_intersect_nodes.append(ne)
                tmp_intersect_nodes = new_tmp_intersect_nodes

            # Grow the context "intersect" front (iteration 1 only).
            if i > passive_context and inpaint_iter == 1:
                new_tmp_intersect_context_nodes = []
                for node in tmp_intersect_context_nodes:
                    for ne in mesh.neighbors(node):
                        if is_free(ne[0], ne[1]):
                            intersect_context_map[ne[0], ne[1]] = True
                            new_tmp_intersect_context_nodes.append(ne)
                tmp_intersect_context_nodes = new_tmp_intersect_context_nodes

            # Grow the mask by one ring.
            new_tmp_mask_nodes = []
            for node in tmp_mask_nodes:
                four_nes = {xx: [] for xx in [(node[0] - 1, node[1]), (node[0] + 1, node[1]),
                                              (node[0], node[1] - 1), (node[0], node[1] + 1)]
                            if 0 <= xx[0] < H and 0 <= xx[1] < W}
                if inpaint_iter > 0:
                    for ne in four_nes:
                        # NOTE(review): connect_map holds depth values, so `== True` only
                        # matches a depth of exactly 1.0; `!= 0` may have been intended.
                        if connect_map[ne[0], ne[1]] == True:
                            tmp_context_nodes.append((ne[0], ne[1], connect_map[ne[0], ne[1]]))
                            context_map[ne[0], ne[1]] = True
                    nes = sorted_layer_neighbors(node, four_nes)
                else:
                    nes = mesh.neighbors(node)
                for ne in nes:
                    if not is_free(ne[0], ne[1]):
                        continue
                    if i == passive_background and inpaint_iter == 0:
                        window = context_map[max(ne[0] - 1, 0):min(ne[0] + 2, H),
                                             max(ne[1] - 1, 0):min(ne[1] + 2, W)]
                        if np.any(window):
                            intersect_map[ne[0], ne[1]] = True
                            tmp_intersect_nodes.append(ne)
                            continue
                    if i < background_thickness:
                        if inpaint_iter == 0 or mesh_nodes[ne].get('inpaint_id') == 1:
                            cur_mask_cc.append(ne)
                        else:
                            continue
                        mask_ccs[edge_id].add(ne)
                        if inpaint_iter == 0:
                            if comp_cnt_depth[node[0], node[1]] > 0 and comp_cnt_depth[ne[0], ne[1]] == 0:
                                comp_cnt_depth[ne[0], ne[1]] = comp_cnt_depth[node[0], node[1]]
                            ne_fars = mesh_nodes[ne].get('far')
                            if ne_fars is not None:
                                for comp_far_node in ne_fars:
                                    cur_comp_far_cc.append(comp_far_node)
                                    cur_accomp_near_cc.append(ne)
                                    cur_invalid_extend_edge_cc.append(comp_far_node)
                            ne_edge_id = mesh_nodes[ne].get('edge_id')
                            if ne_edge_id is not None and len(context_ccs[ne_edge_id]) > 0:
                                intouched_fars = set(ne_fars) if ne_fars is not None else set()
                                accum_intouched_fars = set(intouched_fars)
                                for intouched_far in intouched_fars:
                                    accum_intouched_fars |= set(mesh.neighbors(intouched_far))
                                for intouched_far in accum_intouched_fars:
                                    if mask_map[intouched_far[0], intouched_far[1]] or \
                                            context_map[intouched_far[0], intouched_far[1]]:
                                        continue
                                    tmp_redundant_nodes.add(intouched_far)
                                    intouched_map[intouched_far[0], intouched_far[1]] = True
                            ne_nears = mesh_nodes[ne].get('near')
                            if ne_nears is not None:
                                for intouched_near in set(ne_nears):
                                    if mask_map[intouched_near[0], intouched_near[1]] or \
                                            context_map[intouched_near[0], intouched_near[1]]:
                                        continue
                                    tmp_redundant_nodes.add(intouched_near)
                                    intouched_map[intouched_near[0], intouched_near[1]] = True
                    if not (mesh_nodes[ne].get('inpaint_id') != 1 and inpaint_iter == 1):
                        new_tmp_mask_nodes.append(ne)
                        mask_map[ne[0], ne[1]] = True
            tmp_mask_nodes = new_tmp_mask_nodes

            # Grow the context by one ring.
            new_tmp_context_nodes = []
            for node in tmp_context_nodes:
                if inpaint_iter > 0:
                    four_nes = {(node[0] - 1, node[1]): [], (node[0] + 1, node[1]): [],
                                (node[0], node[1] - 1): [], (node[0], node[1] + 1): []}
                    nes = sorted_layer_neighbors(node, four_nes)
                else:
                    nes = mesh.neighbors(node)
                for ne in nes:
                    if context_map[ne[0], ne[1]] or mask_map[ne[0], ne[1]] or \
                            not forbidden_map[ne[0], ne[1]] or intersect_context_map[ne[0], ne[1]]:
                        continue
                    if i == passive_context and inpaint_iter == 1:
                        if any(mask_map[mne[0], mne[1]] for mne in mesh.neighbors(ne)):
                            intersect_context_map[ne[0], ne[1]] = True
                            tmp_intersect_context_nodes.append(ne)
                            continue
                    new_tmp_context_nodes.append(ne)
                    context_map[ne[0], ne[1]] = True
                    context_depth[ne[0], ne[1]] = ne[2]
            cur_context_cc.extend(new_tmp_context_nodes)
            tmp_erode.append(new_tmp_context_nodes)
            tmp_context_nodes = new_tmp_context_nodes

            tmp_intouched_nodes = grow_intouched(tmp_intouched_nodes)
            tmp_redundant_nodes = grow_intouched(tmp_redundant_nodes)

        if inpaint_iter == 0:
            # Inpaint depth inside the mask and keep only far/near pairs the
            # inpainted depth agrees with (near closer than far).
            depth_dict = get_depth_from_maps(context_map, mask_map, context_depth, H, W, log_depth=config['log_depth'])
            mask_size = get_valid_size(depth_dict['mask'])
            mask_size = dilate_valid_size(mask_size, depth_dict['mask'], dilate=[20, 20])
            context_size = get_valid_size(depth_dict['context'])
            context_size = dilate_valid_size(context_size, depth_dict['context'], dilate=[20, 20])
            union_size = size_operation(mask_size, context_size, operation='+')
            depth_dict = depth_inpainting(None, None, None, None, mesh, config, union_size, depth_feat_model, None,
                                          given_depth_dict=depth_dict, spdb=False)
            near_depth_map, raw_near_depth_map = np.zeros((H, W)), np.zeros((H, W))
            filtered_comp_far_cc, filtered_accomp_near_cc = set(), set()
            for node in cur_accomp_near_cc:
                near_depth_map[node[0], node[1]] = depth_dict['output'][node[0], node[1]]
                raw_near_depth_map[node[0], node[1]] = node[2]
            for node in cur_comp_far_cc:
                four_nes = [xx for xx in [(node[0] - 1, node[1]), (node[0] + 1, node[1]),
                                          (node[0], node[1] - 1), (node[0], node[1] + 1)]
                            if 0 <= xx[0] < H and 0 <= xx[1] < W and
                            near_depth_map[xx[0], xx[1]] != 0 and
                            abs(near_depth_map[xx[0], xx[1]]) < abs(node[2])]
                if len(four_nes) > 0:
                    filtered_comp_far_cc.add(node)
                for ne in four_nes:
                    filtered_accomp_near_cc.add((ne[0], ne[1], -abs(raw_near_depth_map[ne[0], ne[1]])))
            cur_comp_far_cc, cur_accomp_near_cc = filtered_comp_far_cc, filtered_accomp_near_cc
        mask_ccs[edge_id] |= set(cur_mask_cc)
        context_ccs[edge_id] |= set(cur_context_cc)
        accomp_extend_context_ccs[edge_id] |= set(cur_accomp_near_cc).intersection(cur_mask_cc)
        extend_edge_ccs[edge_id] |= set(cur_accomp_near_cc).intersection(cur_mask_cc)
        extend_context_ccs[edge_id] |= set(cur_comp_far_cc)
        invalid_extend_edge_ccs[edge_id] |= set(cur_invalid_extend_edge_cc)

        # Cumulative node counts of the context rings (erode_size[k] = first k rings).
        erode_size = [0]
        for layer in tmp_erode:
            erode_size.append(erode_size[-1] + len(layer))
        tmp_width = config['depth_edge_dilate'] if inpaint_iter == 0 else 0
        tmp_width = min(tmp_width, len(erode_size) - 1)
        while float(erode_size[tmp_width]) / (erode_size[-1] + 1e-6) > 0.3:
            tmp_width = tmp_width - 1
        erode_context_ccs[edge_id] = {n for layer in tmp_erode[:tmp_width] for n in layer}
        for erode_context_node in list(erode_context_ccs[edge_id]):
            if (inpaint_iter != 0 and (mesh_nodes[erode_context_node].get('inpaint_id') is None or
                                       mesh_nodes[erode_context_node].get('inpaint_id') == 0)):
                erode_context_ccs[edge_id].remove(erode_context_node)
            else:
                context_ccs[edge_id].remove(erode_context_node)
        extend_context_ccs[edge_id] = extend_context_ccs[edge_id] - mask_ccs[edge_id] - accomp_extend_context_ccs[edge_id]

    if inpaint_iter == 0:
        for ecnt_id, ecnt_cc in enumerate(extend_context_ccs):
            # Extended context may only grow into the context of the edges on
            # which its accompanying near nodes lie.
            constraint_context_ids = set()
            for ecnt_node in accomp_extend_context_ccs[ecnt_id]:
                if edge_maps[ecnt_node[0], ecnt_node[1]] > -1:
                    constraint_context_ids.add(int(round(edge_maps[ecnt_node[0], ecnt_node[1]])))
            constraint_context_cc = set()
            constraint_erode_context_cc = erode_context_ccs[ecnt_id]
            for constraint_context_id in constraint_context_ids:
                constraint_context_cc = constraint_context_cc | context_ccs[constraint_context_id] | erode_context_ccs[constraint_context_id]
                constraint_erode_context_cc = constraint_erode_context_cc | erode_context_ccs[constraint_context_id]

            accum_context_cc = []
            tmp_context_nodes = copy.deepcopy(ecnt_cc)
            tmp_invalid_context_nodes = copy.deepcopy(invalid_extend_edge_ccs[ecnt_id])
            tmp_mask_nodes = copy.deepcopy(accomp_extend_context_ccs[ecnt_id])
            tmp_context_map = np.zeros((H, W), dtype=bool)
            tmp_mask_map = np.zeros((H, W), dtype=bool)
            tmp_invalid_context_map = np.zeros((H, W), dtype=bool)
            for node in tmp_mask_nodes:
                tmp_mask_map[node[0], node[1]] = True
            for node in context_ccs[ecnt_id]:
                tmp_context_map[node[0], node[1]] = True
            for node in erode_context_ccs[ecnt_id]:
                tmp_context_map[node[0], node[1]] = True
            for node in extend_context_ccs[ecnt_id]:
                tmp_context_map[node[0], node[1]] = True
            for node in invalid_extend_edge_ccs[ecnt_id]:
                tmp_invalid_context_map[node[0], node[1]] = True
            init_invalid_context_map = tmp_invalid_context_map.copy()
            if (tmp_mask_map & tmp_context_map).any():
                warnings.warn('context_and_holes: extended mask and context overlap before growth '
                              '(edge %d)' % ecnt_id)
            if vis_edge_id is not None and ecnt_id == vis_edge_id:
                f, ((ax1, ax2)) = plt.subplots(1, 2, sharex=True, sharey=True)
                ax1.imshow(tmp_context_map * 1); ax2.imshow(init_invalid_context_map * 1 + tmp_context_map * 2)
                plt.show()
                import pdb; pdb.set_trace()

            for _ in range(background_thickness):
                new_tmp_context_nodes = []
                for node in tmp_context_nodes:
                    for ne in mesh.neighbors(node):
                        if ne in constraint_context_cc and \
                                not tmp_mask_map[ne[0], ne[1]] and \
                                not tmp_context_map[ne[0], ne[1]] and \
                                forbidden_map[ne[0], ne[1]]:
                            new_tmp_context_nodes.append(ne)
                            tmp_context_map[ne[0], ne[1]] = True
                accum_context_cc.extend(new_tmp_context_nodes)
                new_tmp_invalid_context_nodes = []
                for node in tmp_invalid_context_nodes:
                    for ne in mesh.neighbors(node):
                        if not tmp_mask_map[ne[0], ne[1]] and \
                                not tmp_context_map[ne[0], ne[1]] and \
                                not tmp_invalid_context_map[ne[0], ne[1]] and \
                                forbidden_map[ne[0], ne[1]]:
                            tmp_invalid_context_map[ne[0], ne[1]] = True
                            new_tmp_invalid_context_nodes.append(ne)
                new_tmp_mask_nodes = set()
                for node in tmp_mask_nodes:
                    for ne in mesh.neighbors(node):
                        if not tmp_mask_map[ne[0], ne[1]] and \
                                not tmp_context_map[ne[0], ne[1]] and \
                                not tmp_invalid_context_map[ne[0], ne[1]] and \
                                forbidden_map[ne[0], ne[1]]:
                            new_tmp_mask_nodes.add(ne)
                            tmp_mask_map[ne[0], ne[1]] = True
                tmp_context_nodes = new_tmp_context_nodes
                tmp_invalid_context_nodes = new_tmp_invalid_context_nodes
                tmp_mask_nodes = new_tmp_mask_nodes

            init_invalid_context_map[tmp_context_map] = False
            if (tmp_mask_map & tmp_context_map).any():
                warnings.warn('context_and_holes: extended mask and context overlap after growth '
                              '(edge %d)' % ecnt_id)
            if vis_edge_id is not None and ecnt_id == vis_edge_id:
                # The label map is only needed for this debug plot.
                _, tmp_label_map = cv2.connectedComponents(
                    (init_invalid_context_map | tmp_context_map).astype(np.uint8), connectivity=8)
                f, ((ax1, ax2)) = plt.subplots(1, 2, sharex=True, sharey=True)
                ax1.imshow(tmp_label_map); ax2.imshow(init_invalid_context_map * 1 + tmp_context_map * 2)
                plt.show()
                import pdb; pdb.set_trace()

            extend_context_ccs[ecnt_id] |= set(accum_context_cc)
            extend_context_ccs[ecnt_id] = extend_context_ccs[ecnt_id] - mask_ccs[ecnt_id]
            extend_erode_context_ccs[ecnt_id] = extend_context_ccs[ecnt_id] & constraint_erode_context_cc
            extend_context_ccs[ecnt_id] = extend_context_ccs[ecnt_id] - extend_erode_context_ccs[ecnt_id] - erode_context_ccs[ecnt_id]
            tmp_context_cc = context_ccs[ecnt_id] - extend_erode_context_ccs[ecnt_id] - erode_context_ccs[ecnt_id]
            if len(tmp_context_cc) > 0:
                context_ccs[ecnt_id] = tmp_context_cc
    return context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs, invalid_extend_edge_ccs, edge_maps, extend_context_ccs, extend_edge_ccs, extend_erode_context_ccs
def DL_inpaint_edge(mesh,
info_on_pix,
config,
image,
depth,
context_ccs,
erode_context_ccs,
extend_context_ccs,
extend_erode_context_ccs,
mask_ccs,
broken_mask_ccs,
edge_ccs,
extend_edge_ccs,
init_mask_connect,
edge_maps,
rgb_model=None,
depth_edge_model=None,
depth_edge_model_init=None,
depth_feat_model=None,
specific_edge_id=-1,
specific_edge_loc=None,
inpaint_iter=0):
if isinstance(config["gpu_ids"], int) and (config["gpu_ids"] >= 0):
device = config["gpu_ids"]
else:
device = "cpu"
edge_map = np.zeros_like(depth)
new_edge_ccs = [set() for _ in range(len(edge_ccs))]
edge_maps_with_id = edge_maps
edge_condition = lambda x, m: m.nodes[x].get('far') is not None and len(m.nodes[x].get('far')) > 0
edge_map = get_map_from_ccs(edge_ccs, mesh.graph['H'], mesh.graph['W'], mesh, edge_condition)
np_depth, np_image = depth.copy(), image.copy()
image_c = image.shape[-1]
image = torch.FloatTensor(image.transpose(2, 0, 1)).unsqueeze(0).to(device)
if depth.ndim < 3:
depth = depth[..., None]
depth = torch.FloatTensor(depth.transpose(2, 0, 1)).unsqueeze(0).to(device)
mesh.graph['max_edge_id'] = len(edge_ccs)
connnect_points_ccs = [set() for _ in range(len(edge_ccs))]
gp_time, tmp_mesh_time, bilateral_time = 0, 0, 0
edges_infos = dict()
edges_in_mask = [set() for _ in range(len(edge_ccs))]
tmp_specific_edge_id = []
for edge_id, (context_cc, mask_cc, erode_context_cc, extend_context_cc, edge_cc) in enumerate(zip(context_ccs, mask_ccs, erode_context_ccs, extend_context_ccs, edge_ccs)):
if len(specific_edge_id) > 0:
if edge_id not in specific_edge_id:
continue
if len(context_cc) < 1 or len(mask_cc) < 1:
continue
edge_dict = get_edge_from_nodes(context_cc | extend_context_cc, erode_context_cc | extend_erode_context_ccs[edge_id], mask_cc, edge_cc, extend_edge_ccs[edge_id],
mesh.graph['H'], mesh.graph['W'], mesh)
edge_dict['edge'], end_depth_maps, _ = \
filter_irrelevant_edge_new(edge_dict['self_edge'], edge_dict['comp_edge'],
edge_map,
edge_maps_with_id,
edge_id,
edge_dict['context'],
edge_dict['depth'], mesh, context_cc | erode_context_cc | extend_context_cc | extend_erode_context_ccs[edge_id], spdb=False)
if specific_edge_loc is not None and \
(specific_edge_loc is not None and edge_dict['mask'][specific_edge_loc[0], specific_edge_loc[1]] == 0):
continue
mask_size = get_valid_size(edge_dict['mask'])
mask_size = dilate_valid_size(mask_size, edge_dict['mask'], dilate=[20, 20])
context_size = get_valid_size(edge_dict['context'])
context_size = dilate_valid_size(context_size, edge_dict['context'], dilate=[20, 20])
union_size = size_operation(mask_size, context_size, operation='+')
patch_edge_dict = dict()
patch_edge_dict['mask'], patch_edge_dict['context'], patch_edge_dict['rgb'], \
patch_edge_dict['disp'], patch_edge_dict['edge'] = \
crop_maps_by_size(union_size, edge_dict['mask'], edge_dict['context'],
edge_dict['rgb'], edge_dict['disp'], edge_dict['edge'])
x_anchor, y_anchor = [union_size['x_min'], union_size['x_max']], [union_size['y_min'], union_size['y_max']]
tensor_edge_dict = convert2tensor(patch_edge_dict)
input_edge_feat = torch.cat((tensor_edge_dict['rgb'],
tensor_edge_dict['disp'],
tensor_edge_dict['edge'],
1 - tensor_edge_dict['context'],
tensor_edge_dict['mask']), dim=1)
if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) and inpaint_iter == 0:
with torch.no_grad():
depth_edge_output = depth_edge_model.forward_3P(tensor_edge_dict['mask'],
tensor_edge_dict['context'],
tensor_edge_dict['rgb'],
tensor_edge_dict['disp'],
tensor_edge_dict['edge'],
unit_length=128,
cuda=device)
depth_edge_output = depth_edge_output.cpu()
tensor_edge_dict['output'] = (depth_edge_output> config['ext_edge_threshold']).float() * tensor_edge_dict['mask'] + tensor_edge_dict['edge']
else:
tensor_edge_dict['output'] = tensor_edge_dict['edge']
depth_edge_output = tensor_edge_dict['edge'] + 0
patch_edge_dict['output'] = tensor_edge_dict['output'].squeeze().data.cpu().numpy()
edge_dict['output'] = np.zeros((mesh.graph['H'], mesh.graph['W']))
edge_dict['output'][union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = \
patch_edge_dict['output']
if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) and inpaint_iter == 0:
if ((depth_edge_output> config['ext_edge_threshold']).float() * tensor_edge_dict['mask']).max() > 0:
try:
edge_dict['fpath_map'], edge_dict['npath_map'], break_flag, npaths, fpaths, invalid_edge_id = \
clean_far_edge_new(edge_dict['output'], end_depth_maps, edge_dict['mask'], edge_dict['context'], mesh, info_on_pix, edge_dict['self_edge'], inpaint_iter, config)
except:
import pdb; pdb.set_trace()
pre_npath_map = edge_dict['npath_map'].copy()
if config.get('repeat_inpaint_edge') is True:
for _ in range(2):
tmp_input_edge = ((edge_dict['npath_map'] > -1) + edge_dict['edge']).clip(0, 1)
patch_tmp_input_edge = crop_maps_by_size(union_size, tmp_input_edge)[0]
tensor_input_edge = torch.FloatTensor(patch_tmp_input_edge)[None, None, ...]
depth_edge_output = depth_edge_model.forward_3P(tensor_edge_dict['mask'],
tensor_edge_dict['context'],
tensor_edge_dict['rgb'],
tensor_edge_dict['disp'],
tensor_input_edge,
unit_length=128,
cuda=device)
depth_edge_output = depth_edge_output.cpu()
depth_edge_output = (depth_edge_output> config['ext_edge_threshold']).float() * tensor_edge_dict['mask'] + tensor_edge_dict['edge']
depth_edge_output = depth_edge_output.squeeze().data.cpu().numpy()
full_depth_edge_output = np.zeros((mesh.graph['H'], mesh.graph['W']))
full_depth_edge_output[union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = \
depth_edge_output
edge_dict['fpath_map'], edge_dict['npath_map'], break_flag, npaths, fpaths, invalid_edge_id = \
clean_far_edge_new(full_depth_edge_output, end_depth_maps, edge_dict['mask'], edge_dict['context'], mesh, info_on_pix, edge_dict['self_edge'], inpaint_iter, config)
for nid in npaths.keys():
npath, fpath = npaths[nid], fpaths[nid]
start_mx, start_my, end_mx, end_my = -1, -1, -1, -1
if end_depth_maps[npath[0][0], npath[0][1]] != 0:
start_mx, start_my = npath[0][0], npath[0][1]
if end_depth_maps[npath[-1][0], npath[-1][1]] != 0:
end_mx, end_my = npath[-1][0], npath[-1][1]
if start_mx == -1:
import pdb; pdb.set_trace()
valid_end_pt = () if end_mx == -1 else (end_mx, end_my, info_on_pix[(end_mx, end_my)][0]['depth'])
new_edge_info = dict(fpath=fpath,
npath=npath,
cont_end_pts=valid_end_pt,
mask_id=edge_id,
comp_edge_id=nid,
depth=end_depth_maps[start_mx, start_my])
if edges_infos.get((start_mx, start_my)) is None:
edges_infos[(start_mx, start_my)] = []
edges_infos[(start_mx, start_my)].append(new_edge_info)
edges_in_mask[edge_id].add((start_mx, start_my))
if len(valid_end_pt) > 0:
new_edge_info = dict(fpath=fpath[::-1],
npath=npath[::-1],
cont_end_pts=(start_mx, start_my, info_on_pix[(start_mx, start_my)][0]['depth']),
mask_id=edge_id,
comp_edge_id=nid,
depth=end_depth_maps[end_mx, end_my])
if edges_infos.get((end_mx, end_my)) is None:
edges_infos[(end_mx, end_my)] = []
edges_infos[(end_mx, end_my)].append(new_edge_info)
edges_in_mask[edge_id].add((end_mx, end_my))
for edge_id, (context_cc, mask_cc, erode_context_cc, extend_context_cc, edge_cc) in enumerate(zip(context_ccs, mask_ccs, erode_context_ccs, extend_context_ccs, edge_ccs)):
if len(specific_edge_id) > 0:
if edge_id not in specific_edge_id:
continue
if len(context_cc) < 1 or len(mask_cc) < 1:
continue
edge_dict = get_edge_from_nodes(context_cc | extend_context_cc, erode_context_cc | extend_erode_context_ccs[edge_id], mask_cc, edge_cc, extend_edge_ccs[edge_id],
mesh.graph['H'], mesh.graph['W'], mesh)
if specific_edge_loc is not None and \
(specific_edge_loc is not None and edge_dict['mask'][specific_edge_loc[0], specific_edge_loc[1]] == 0):
continue
else:
tmp_specific_edge_id.append(edge_id)
edge_dict['edge'], end_depth_maps, _ = \
filter_irrelevant_edge_new(edge_dict['self_edge'], edge_dict['comp_edge'],
edge_map,
edge_maps_with_id,
edge_id,
edge_dict['context'],
edge_dict['depth'], mesh, context_cc | erode_context_cc | extend_context_cc | extend_erode_context_ccs[edge_id], spdb=False)
discard_map = np.zeros_like(edge_dict['edge'])
mask_size = get_valid_size(edge_dict['mask'])
mask_size = dilate_valid_size(mask_size, edge_dict['mask'], dilate=[20, 20])
context_size = get_valid_size(edge_dict['context'])
context_size = dilate_valid_size(context_size, edge_dict['context'], dilate=[20, 20])
union_size = size_operation(mask_size, context_size, operation='+')
patch_edge_dict = dict()
patch_edge_dict['mask'], patch_edge_dict['context'], patch_edge_dict['rgb'], \
patch_edge_dict['disp'], patch_edge_dict['edge'] = \
crop_maps_by_size(union_size, edge_dict['mask'], edge_dict['context'],
edge_dict['rgb'], edge_dict['disp'], edge_dict['edge'])
x_anchor, y_anchor = [union_size['x_min'], union_size['x_max']], [union_size['y_min'], union_size['y_max']]
tensor_edge_dict = convert2tensor(patch_edge_dict)
input_edge_feat = torch.cat((tensor_edge_dict['rgb'],
tensor_edge_dict['disp'],
tensor_edge_dict['edge'],
1 - tensor_edge_dict['context'],
tensor_edge_dict['mask']), dim=1)
edge_dict['output'] = edge_dict['edge'].copy()
if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) and inpaint_iter == 0:
edge_dict['fpath_map'], edge_dict['npath_map'] = edge_dict['fpath_map'] * 0 - 1, edge_dict['npath_map'] * 0 - 1
end_pts = edges_in_mask[edge_id]
for end_pt in end_pts:
cur_edge_infos = edges_infos[(end_pt[0], end_pt[1])]
cur_info = [xx for xx in cur_edge_infos if xx['mask_id'] == edge_id][0]
other_infos = [xx for xx in cur_edge_infos if xx['mask_id'] != edge_id and len(xx['cont_end_pts']) > 0]
if len(cur_info['cont_end_pts']) > 0 or (len(cur_info['cont_end_pts']) == 0 and len(other_infos) == 0):
for fnode in cur_info['fpath']:
edge_dict['fpath_map'][fnode[0], fnode[1]] = cur_info['comp_edge_id']
for fnode in cur_info['npath']:
edge_dict['npath_map'][fnode[0], fnode[1]] = cur_info['comp_edge_id']
fnmap = edge_dict['fpath_map'] * 1
fnmap[edge_dict['npath_map'] != -1] = edge_dict['npath_map'][edge_dict['npath_map'] != -1]
for end_pt in end_pts:
cur_edge_infos = edges_infos[(end_pt[0], end_pt[1])]
cur_info = [xx for xx in cur_edge_infos if xx['mask_id'] == edge_id][0]
cur_depth = cur_info['depth']
other_infos = [xx for xx in cur_edge_infos if xx['mask_id'] != edge_id and len(xx['cont_end_pts']) > 0]
comp_edge_id = cur_info['comp_edge_id']
if len(cur_info['cont_end_pts']) == 0 and len(other_infos) > 0:
other_infos = sorted(other_infos, key=lambda aa: abs(abs(aa['cont_end_pts'][2]) - abs(cur_depth)))
for other_info in other_infos:
tmp_fmap, tmp_nmap = np.zeros((mesh.graph['H'], mesh.graph['W'])) - 1, np.zeros((mesh.graph['H'], mesh.graph['W'])) - 1
for fnode in other_info['fpath']:
if fnmap[fnode[0], fnode[1]] != -1:
tmp_fmap = tmp_fmap * 0 - 1
break
else:
tmp_fmap[fnode[0], fnode[1]] = comp_edge_id
if fnmap[fnode[0], fnode[1]] != -1:
continue
for fnode in other_info['npath']:
if fnmap[fnode[0], fnode[1]] != -1:
tmp_nmap = tmp_nmap * 0 - 1
break
else:
tmp_nmap[fnode[0], fnode[1]] = comp_edge_id
if fnmap[fnode[0], fnode[1]] != -1:
continue
break
if min(tmp_fmap.max(), tmp_nmap.max()) != -1:
edge_dict['fpath_map'] = tmp_fmap
edge_dict['fpath_map'][edge_dict['valid_area'] == 0] = -1
edge_dict['npath_map'] = tmp_nmap
edge_dict['npath_map'][edge_dict['valid_area'] == 0] = -1
discard_map = ((tmp_nmap != -1).astype(np.uint8) + (tmp_fmap != -1).astype(np.uint8)) * edge_dict['mask']
else:
for fnode in cur_info['fpath']:
edge_dict['fpath_map'][fnode[0], fnode[1]] = cur_info['comp_edge_id']
for fnode in cur_info['npath']:
edge_dict['npath_map'][fnode[0], fnode[1]] = cur_info['comp_edge_id']
if edge_dict['npath_map'].min() == 0 or edge_dict['fpath_map'].min() == 0:
import pdb; pdb.set_trace()
edge_dict['output'] = (edge_dict['npath_map'] > -1) * edge_dict['mask'] + edge_dict['context'] * edge_dict['edge']
mesh, _, _, _ = create_placeholder(edge_dict['context'], edge_dict['mask'],
edge_dict['depth'], edge_dict['fpath_map'],
edge_dict['npath_map'], mesh, inpaint_iter,
edge_ccs,
extend_edge_ccs[edge_id],
edge_maps_with_id,
edge_id)
dxs, dys = np.where(discard_map != 0)
for dx, dy in zip(dxs, dys):
mesh.nodes[(dx, dy)]['inpaint_twice'] = False
depth_dict = depth_inpainting(context_cc, extend_context_cc, erode_context_cc | extend_erode_context_ccs[edge_id], mask_cc, mesh, config, union_size, depth_feat_model, edge_dict['output'])
refine_depth_output = depth_dict['output']*depth_dict['mask']
for near_id in np.unique(edge_dict['npath_map'])[1:]:
refine_depth_output = refine_depth_around_edge(refine_depth_output.copy(),
(edge_dict['fpath_map'] == near_id).astype(np.uint8) * edge_dict['mask'],
(edge_dict['fpath_map'] == near_id).astype(np.uint8),
(edge_dict['npath_map'] == near_id).astype(np.uint8) * edge_dict['mask'],
depth_dict['mask'].copy(),
depth_dict['output'] * depth_dict['context'],
config)
depth_dict['output'][depth_dict['mask'] > 0] = refine_depth_output[depth_dict['mask'] > 0]
rgb_dict = get_rgb_from_nodes(context_cc | extend_context_cc,
erode_context_cc | extend_erode_context_ccs[edge_id], mask_cc, mesh.graph['H'], mesh.graph['W'], mesh)
if np.all(rgb_dict['mask'] == edge_dict['mask']) is False:
import pdb; pdb.set_trace()
rgb_dict['edge'] = edge_dict['output']
patch_rgb_dict = dict()
patch_rgb_dict['mask'], patch_rgb_dict['context'], patch_rgb_dict['rgb'], \
patch_rgb_dict['edge'] = crop_maps_by_size(union_size, rgb_dict['mask'],
rgb_dict['context'], rgb_dict['rgb'],
rgb_dict['edge'])
tensor_rgb_dict = convert2tensor(patch_rgb_dict)
resize_rgb_dict = {k: v.clone() for k, v in tensor_rgb_dict.items()}
max_hw = np.array([*patch_rgb_dict['mask'].shape[-2:]]).max()
init_frac = config['largest_size'] / (np.array([*patch_rgb_dict['mask'].shape[-2:]]).prod() ** 0.5)
resize_hw = [patch_rgb_dict['mask'].shape[-2] * init_frac, patch_rgb_dict['mask'].shape[-1] * init_frac]
resize_max_hw = max(resize_hw)
frac = (np.floor(resize_max_hw / 128.) * 128.) / max_hw
if frac < 1:
resize_mark = torch.nn.functional.interpolate(torch.cat((resize_rgb_dict['mask'],
resize_rgb_dict['context']),
dim=1),
scale_factor=frac,
mode='area')
resize_rgb_dict['mask'] = (resize_mark[:, 0:1] > 0).float()
resize_rgb_dict['context'] = (resize_mark[:, 1:2] == 1).float()
resize_rgb_dict['context'][resize_rgb_dict['mask'] > 0] = 0
resize_rgb_dict['rgb'] = torch.nn.functional.interpolate(resize_rgb_dict['rgb'],
scale_factor=frac,
mode='area')
resize_rgb_dict['rgb'] = resize_rgb_dict['rgb'] * resize_rgb_dict['context']
resize_rgb_dict['edge'] = torch.nn.functional.interpolate(resize_rgb_dict['edge'],
scale_factor=frac,
mode='area')
resize_rgb_dict['edge'] = (resize_rgb_dict['edge'] > 0).float() * 0
resize_rgb_dict['edge'] = resize_rgb_dict['edge'] * (resize_rgb_dict['context'] + resize_rgb_dict['mask'])
rgb_input_feat = torch.cat((resize_rgb_dict['rgb'], resize_rgb_dict['edge']), dim=1)
rgb_input_feat[:, 3] = 1 - rgb_input_feat[:, 3]
resize_mask = open_small_mask(resize_rgb_dict['mask'], resize_rgb_dict['context'], 3, 41)
specified_hole = resize_mask
with torch.no_grad():
rgb_output = rgb_model.forward_3P(specified_hole,
resize_rgb_dict['context'],
resize_rgb_dict['rgb'],
resize_rgb_dict['edge'],
unit_length=128,
cuda=device)
rgb_output = rgb_output.cpu()
if config.get('gray_image') is True:
rgb_output = rgb_output.mean(1, keepdim=True).repeat((1,3,1,1))
rgb_output = rgb_output.cpu()
resize_rgb_dict['output'] = rgb_output * resize_rgb_dict['mask'] + resize_rgb_dict['rgb']
tensor_rgb_dict['output'] = resize_rgb_dict['output']
if frac < 1:
tensor_rgb_dict['output'] = torch.nn.functional.interpolate(tensor_rgb_dict['output'],
size=tensor_rgb_dict['mask'].shape[-2:],
mode='bicubic')
tensor_rgb_dict['output'] = tensor_rgb_dict['output'] * \
tensor_rgb_dict['mask'] + (tensor_rgb_dict['rgb'] * tensor_rgb_dict['context'])
patch_rgb_dict['output'] = tensor_rgb_dict['output'].data.cpu().numpy().squeeze().transpose(1,2,0)
rgb_dict['output'] = np.zeros((mesh.graph['H'], mesh.graph['W'], 3))
rgb_dict['output'][union_size['x_min']:union_size['x_max'], union_size['y_min']:union_size['y_max']] = \
patch_rgb_dict['output']
if require_depth_edge(patch_edge_dict['edge'], patch_edge_dict['mask']) or inpaint_iter > 0:
edge_occlusion = True
else:
edge_occlusion = False
for node in erode_context_cc:
if rgb_dict['mask'][node[0], node[1]] > 0:
for info in info_on_pix[(node[0], node[1])]:
if abs(info['depth']) == abs(node[2]):
info['update_color'] = (rgb_dict['output'][node[0], node[1]] * 255).astype(np.uint8)
if frac < 1.:
depth_edge_dilate_2_color_flag = False
else:
depth_edge_dilate_2_color_flag = True
hxs, hys = np.where((rgb_dict['mask'] > 0) & (rgb_dict['erode'] == 0))
for hx, hy in zip(hxs, hys):
real_depth = None
if abs(depth_dict['output'][hx, hy]) <= abs(np_depth[hx, hy]):
depth_dict['output'][hx, hy] = np_depth[hx, hy] + 0.01
node = (hx, hy, -depth_dict['output'][hx, hy])
if info_on_pix.get((node[0], node[1])) is not None:
for info in info_on_pix.get((node[0], node[1])):
if info.get('inpaint_id') is None or abs(info['inpaint_id'] < mesh.nodes[(hx, hy)]['inpaint_id']):
pre_depth = info['depth'] if info.get('real_depth') is None else info['real_depth']
if abs(node[2]) < abs(pre_depth):
node = (node[0], node[1], -(abs(pre_depth) + 0.001))
if mesh.has_node(node):
real_depth = node[2]
while True:
if mesh.has_node(node):
node = (node[0], node[1], -(abs(node[2]) + 0.001))
else:
break
if real_depth == node[2]:
real_depth = None
cur_disp = 1./node[2]
if not(mesh.has_node(node)):
if not mesh.has_node((node[0], node[1])):
print("2D node not found.")
import pdb; pdb.set_trace()
if inpaint_iter == 1:
paint = (rgb_dict['output'][hx, hy] * 255).astype(np.uint8)
else:
paint = (rgb_dict['output'][hx, hy] * 255).astype(np.uint8)
ndict = dict(color=paint,
synthesis=True,
disp=cur_disp,
cc_id=set([edge_id]),
overlap_number=1.0,
refine_depth=False,
edge_occlusion=edge_occlusion,
depth_edge_dilate_2_color_flag=depth_edge_dilate_2_color_flag,
real_depth=real_depth)
mesh, _, _ = refresh_node((node[0], node[1]), mesh.nodes[(node[0], node[1])], node, ndict, mesh, stime=True)
if inpaint_iter == 0 and mesh.degree(node) < 4:
connnect_points_ccs[edge_id].add(node)
if info_on_pix.get((hx, hy)) is None:
info_on_pix[(hx, hy)] = []
new_info = {'depth':node[2],
'color': paint,
'synthesis':True,
'disp':cur_disp,
'cc_id':set([edge_id]),
'inpaint_id':inpaint_iter + 1,
'edge_occlusion':edge_occlusion,
'overlap_number':1.0,
'real_depth': real_depth}
info_on_pix[(hx, hy)].append(new_info)
specific_edge_id = tmp_specific_edge_id
for erode_id, erode_context_cc in enumerate(erode_context_ccs):
if len(specific_edge_id) > 0 and erode_id not in specific_edge_id:
continue
for erode_node in erode_context_cc:
for info in info_on_pix[(erode_node[0], erode_node[1])]:
if info['depth'] == erode_node[2]:
info['color'] = info['update_color']
mesh.nodes[erode_node]['color'] = info['update_color']
np_image[(erode_node[0], erode_node[1])] = info['update_color']
new_edge_ccs = [set() for _ in range(mesh.graph['max_edge_id'] + 1)]
for node in mesh.nodes:
if len(node) == 2:
mesh.remove_node(node)
continue
if mesh.nodes[node].get('edge_id') is not None and mesh.nodes[node].get('inpaint_id') == inpaint_iter + 1:
if mesh.nodes[node].get('inpaint_twice') is False:
continue
try:
new_edge_ccs[mesh.nodes[node].get('edge_id')].add(node)
except:
import pdb; pdb.set_trace()
specific_mask_nodes = None
if inpaint_iter == 0:
mesh, info_on_pix = refine_color_around_edge(mesh, info_on_pix, new_edge_ccs, config, False)
return mesh, info_on_pix, specific_mask_nodes, new_edge_ccs, connnect_points_ccs, np_image
def write_ply(image,
              depth,
              int_mtx,
              ply_name,
              config,
              rgb_model,
              depth_edge_model,
              depth_edge_model_init,
              depth_feat_model):
    """Build, clean, inpaint and export the layered-depth mesh for one image.

    Stages, in order:
      1. ``create_mesh`` from ``depth``/``image``, then ``tear_edges``,
         node initialisation, floating-island reassignment and removal of
         dangling edges.
      2. If ``config['extrapolate_border'] is True``: enlarge the border and
         ``extrapolate`` outward in eight directions.
      3. Two rounds of ``context_and_holes`` + ``DL_inpaint_edge``
         (``inpaint_iter`` 0 and 1).
      4. Serialize every pixel layer in ``info_on_pix`` as a vertex and build
         faces with ``generate_face``.

    Args:
        image: input colour image, passed to ``create_mesh``.
        depth: depth map; converted to float64 before use.
        int_mtx: intrinsic matrix passed to ``create_mesh``.
        ply_name: output path; only written when ``config['save_ply']`` is True.
        config: configuration dict. Keys read directly here:
            ``'depth_threshold'``, ``'extrapolate_border'``,
            ``'redundant_number'`` (border path only) and ``'save_ply'``.
        rgb_model, depth_edge_model, depth_edge_model_init, depth_feat_model:
            inpainting networks forwarded to the helper stages.

    Returns:
        * ``False`` if some layer in ``info_on_pix`` has no matching mesh node.
          In that case nothing is written.
        * The mesh graph, after writing ``ply_name``, when
          ``config.get('save_ply') is True``.
        * Otherwise the tuple ``(points, colors, faces, H, W, hFov, vFov)``
          with numpy arrays. ``colors[..., :3]`` is scaled to [0, 1] and
          ``colors[..., 3]`` holds the per-vertex code (see below).
    """
    depth = depth.astype(np.float64)
    input_mesh, xy2depth, image, depth = create_mesh(depth, image, int_mtx, config)

    # --- 1. Initial mesh construction and clean-up ----------------------
    input_mesh = tear_edges(input_mesh, config['depth_threshold'], xy2depth)
    input_mesh, info_on_pix = generate_init_node(input_mesh, config, min_node_in_cc=200)
    edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=False)
    input_mesh, info_on_pix, depth = reassign_floating_island(input_mesh, info_on_pix, image, depth)
    input_mesh = update_status(input_mesh, info_on_pix)
    edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
    input_mesh, info_on_pix, edge_mesh, depth, _ = remove_dangling(input_mesh, edge_ccs, edge_mesh,
                                                                   info_on_pix, image, depth, config)
    input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
    edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
    # The returned mesh is discarded and later stages keep using
    # ``input_mesh``. This presumably relies on fill_missing_node mutating
    # the graph in place -- TODO confirm.
    _, info_on_pix, depth = fill_missing_node(input_mesh, info_on_pix, image, depth)

    # --- 2. Optional border extrapolation --------------------------------
    if config['extrapolate_border'] is True:
        input_mesh, info_on_pix, depth = refresh_bord_depth(input_mesh, info_on_pix, image, depth)
        input_mesh = remove_node_feat(input_mesh, 'edge_id')
        input_mesh, info_on_pix, depth, image = enlarge_border(input_mesh, info_on_pix, depth, image, config)
        input_mesh, info_on_pix = fill_dummy_bord(input_mesh, info_on_pix, image, depth, config)
        edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
        input_mesh = combine_end_node(input_mesh, edge_mesh, edge_ccs, depth)
        input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
        edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image,
                                                      remove_conflict_ordinal=True, spdb=False)
        input_mesh = remove_redundant_edge(input_mesh, edge_mesh, edge_ccs, info_on_pix, config,
                                           redundant_number=config['redundant_number'], spdb=False)
        input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
        edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
        input_mesh = combine_end_node(input_mesh, edge_mesh, edge_ccs, depth)
        input_mesh = remove_redundant_edge(input_mesh, edge_mesh, edge_ccs, info_on_pix, config,
                                           redundant_number=config['redundant_number'],
                                           invalid=True, spdb=False)
        input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
        edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)
        input_mesh = combine_end_node(input_mesh, edge_mesh, edge_ccs, depth)
        input_mesh, depth, info_on_pix = update_status(input_mesh, info_on_pix, depth)
        edge_ccs, input_mesh, edge_mesh = group_edges(input_mesh, config, image, remove_conflict_ordinal=True)

        def edge_condition(x, m):
            # Edge nodes that have a non-empty 'far' attribute.
            far = m.nodes[x].get('far')
            return far is not None and len(far) > 0

        edge_map = get_map_from_ccs(edge_ccs, input_mesh.graph['H'], input_mesh.graph['W'],
                                    input_mesh, edge_condition)
        other_edge_with_id = get_map_from_ccs(edge_ccs, input_mesh.graph['H'], input_mesh.graph['W'],
                                              real_id=True)
        # Each pass consumes the outputs of the previous one, so the
        # direction order is significant.
        for direc in ("up", "left", "down", "right",
                      "right-up", "right-down", "left-up", "left-down"):
            info_on_pix, input_mesh, image, depth, edge_ccs = extrapolate(
                input_mesh, info_on_pix, image, depth, other_edge_with_id, edge_map, edge_ccs,
                depth_edge_model, depth_feat_model, rgb_model, config, direc=direc)

    # --- 3a. First inpainting round (inpaint_iter=0) --------------------
    specific_edge_loc = None
    # The same list object is shared by both calls of this round, as before.
    specific_edge_id = []
    context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs, \
        init_mask_connect, edge_maps, extend_context_ccs, extend_edge_ccs, extend_erode_context_ccs = \
        context_and_holes(input_mesh,
                          edge_ccs,
                          config,
                          specific_edge_id,
                          specific_edge_loc,
                          depth_feat_model,
                          inpaint_iter=0,
                          vis_edge_id=None)
    # NOTE(review): the result was never used. The call is kept in case
    # filter_edge has side effects on the mesh -- confirm and drop if not.
    filter_edge(input_mesh, edge_ccs, config)
    input_mesh, info_on_pix, _, new_edge_ccs, connect_points_ccs, image = DL_inpaint_edge(
        input_mesh,
        info_on_pix,
        config,
        image,
        depth,
        context_ccs,
        erode_context_ccs,
        extend_context_ccs,
        extend_erode_context_ccs,
        mask_ccs,
        broken_mask_ccs,
        edge_ccs,
        extend_edge_ccs,
        init_mask_connect,
        edge_maps,
        rgb_model,
        depth_edge_model,
        depth_edge_model_init,
        depth_feat_model,
        specific_edge_id,
        specific_edge_loc,
        inpaint_iter=0)

    # --- 3b. Second inpainting round (inpaint_iter=1) -------------------
    # Start the second round with empty connection-point sets, one per
    # component returned by the first round.
    connect_points_ccs = [set() for _ in connect_points_ccs]
    specific_edge_id = []
    context_ccs, mask_ccs, broken_mask_ccs, edge_ccs, erode_context_ccs, init_mask_connect, \
        edge_maps, extend_context_ccs, extend_edge_ccs, extend_erode_context_ccs = \
        context_and_holes(input_mesh, new_edge_ccs, config, specific_edge_id, specific_edge_loc,
                          depth_feat_model, connect_points_ccs, inpaint_iter=1)
    specific_edge_id = []
    input_mesh, info_on_pix, _, _, _, image = DL_inpaint_edge(
        input_mesh,
        info_on_pix,
        config,
        image,
        depth,
        context_ccs,
        erode_context_ccs,
        extend_context_ccs,
        extend_erode_context_ccs,
        mask_ccs,
        broken_mask_ccs,
        edge_ccs,
        extend_edge_ccs,
        init_mask_connect,
        edge_maps,
        rgb_model,
        depth_edge_model,
        depth_edge_model_init,
        depth_feat_model,
        specific_edge_id,
        specific_edge_loc,
        inpaint_iter=1)

    # --- 4. Vertex serialization ----------------------------------------
    # The exported H/W metadata are the sizes without extrapolation padding.
    input_mesh.graph['H'], input_mesh.graph['W'] = input_mesh.graph['noext_H'], input_mesh.graph['noext_W']
    ply_flag = config.get('save_ply')
    if ply_flag is True:
        node_str_list = []
    else:
        node_str_color = []
        node_str_point = []

    def out_fmt(x, x_flag):
        # Values are stringified only when building the text PLY.
        return str(x) if x_flag is True else x

    cam_inv = input_mesh.graph['cam_param_pix_inv']
    k_00, k_02, k_11, k_12 = cam_inv[0, 0], cam_inv[0, 2], cam_inv[1, 1], cam_inv[1, 2]
    w_offset = input_mesh.graph['woffset']
    h_offset = input_mesh.graph['hoffset']
    vertex_id = 0
    for pix_xy, pix_list in info_on_pix.items():
        for pix_info in pix_list:
            node_key = (pix_xy[0], pix_xy[1], pix_info['depth'])
            # Position uses 'real_depth' when present, otherwise 'depth'.
            pix_depth = pix_info['depth'] if pix_info.get('real_depth') is None else pix_info['real_depth']
            str_pt = [out_fmt(x, ply_flag)
                      for x in reproject_3d_int_detail(pix_xy[0], pix_xy[1], pix_depth,
                                                       k_00, k_02, k_11, k_12, w_offset, h_offset)]
            # A layer with no mesh node means info_on_pix and the mesh
            # disagree. Abort and signal failure to the caller.
            if not input_mesh.has_node(node_key):
                return False
            if pix_info.get('overlap_number') is not None:
                raw_color = (pix_info['color'] / pix_info['overlap_number']).astype(np.uint8).tolist()
            else:
                raw_color = pix_info['color'].tolist()
            str_color = [out_fmt(x, ply_flag) for x in raw_color]
            # Extra per-vertex code stored in the "alpha" property:
            #   4 = 'edge_occlusion' is True
            #   1 = no 'inpaint_id' recorded
            #   inpaint_id + 1 = otherwise
            #   5 = 'modified_border' or 'ext_pixel' set (overrides the above)
            if pix_info.get('edge_occlusion') is True:
                code = 4
            elif pix_info.get('inpaint_id') is None:
                code = 1
            else:
                code = pix_info.get('inpaint_id') + 1
            str_color.append(out_fmt(code, ply_flag))
            if pix_info.get('modified_border') is True or pix_info.get('ext_pixel') is True:
                if len(str_color) == 4:
                    str_color[-1] = out_fmt(5, ply_flag)
                else:
                    str_color.append(out_fmt(5, ply_flag))
            # 'cur_id' links the info entry and the mesh node to its vertex
            # index (presumably consumed by generate_face).
            pix_info['cur_id'] = vertex_id
            input_mesh.nodes[node_key]['cur_id'] = out_fmt(vertex_id, ply_flag)
            vertex_id += 1
            if ply_flag is True:
                node_str_list.append(' '.join(str_pt) + ' ' + ' '.join(str_color) + '\n')
            else:
                node_str_color.append(str_color)
                node_str_point.append(str_pt)
    str_faces = generate_face(input_mesh, info_on_pix, config)

    # Branch on the same flag that chose the vertex container above. Reading
    # config['save_ply'] again here raised KeyError when the key was absent.
    if ply_flag is True:
        print("Writing mesh file %s ..." % ply_name)
        with open(ply_name, 'w') as ply_fi:
            ply_fi.write('ply\n' + 'format ascii 1.0\n')
            ply_fi.write('comment H ' + str(int(input_mesh.graph['H'])) + '\n')
            ply_fi.write('comment W ' + str(int(input_mesh.graph['W'])) + '\n')
            ply_fi.write('comment hFov ' + str(float(input_mesh.graph['hFov'])) + '\n')
            ply_fi.write('comment vFov ' + str(float(input_mesh.graph['vFov'])) + '\n')
            ply_fi.write('element vertex ' + str(len(node_str_list)) + '\n')
            ply_fi.write('property float x\n' +
                         'property float y\n' +
                         'property float z\n' +
                         'property uchar red\n' +
                         'property uchar green\n' +
                         'property uchar blue\n' +
                         'property uchar alpha\n')
            ply_fi.write('element face ' + str(len(str_faces)) + '\n')
            ply_fi.write('property list uchar int vertex_index\n')
            ply_fi.write('end_header\n')
            ply_fi.writelines(node_str_list)
            ply_fi.writelines(str_faces)
        return input_mesh

    H = int(input_mesh.graph['H'])
    W = int(input_mesh.graph['W'])
    hFov = input_mesh.graph['hFov']
    vFov = input_mesh.graph['vFov']
    node_str_color = np.array(node_str_color).astype(np.float32)
    node_str_color[..., :3] = node_str_color[..., :3] / 255.
    node_str_point = np.array(node_str_point)
    str_faces = np.array(str_faces)
    return node_str_point, node_str_color, str_faces, H, W, hFov, vFov
def read_ply(mesh_fi):
    """Read an ASCII PLY triangle mesh in the layout produced by ``write_ply``.

    Header ``comment H|W|hFov|vFov <value>`` lines supply the returned
    metadata. Vertex lines are ``x y z r g b [a]`` and face lines are
    ``<count> v1 v2 v3`` (the leading count is ignored).

    Args:
        mesh_fi: path to the PLY file.

    Returns:
        ``(verts, colors, faces, Height, Width, hFov, vFov)``:

        * ``verts``: (N, 3) float array.
        * ``colors``: (N, 4) float array. RGB is divided by 255; the fourth
          column is the raw alpha value (1.0 for lines without one).
        * ``faces``: (F, 3) int array.
        * The metadata values are ``None`` when their comment is missing.

    Raises:
        ValueError: if the header has no ``end_header`` line, declares no
            ``element vertex`` count, or a vertex/face line has an
            unexpected number of fields.
    """
    Height = None
    Width = None
    hFov = None
    vFov = None
    num_vertex = None
    num_face = None
    # Close the file deterministically; the original handle was never closed.
    with open(mesh_fi, 'r') as ply_fi:
        while True:
            raw = ply_fi.readline()
            if raw == '':
                # EOF before end_header: fail instead of spinning forever.
                raise ValueError("PLY header in %s has no end_header line" % mesh_fi)
            line = raw.strip()
            tokens = line.split()
            if line.startswith('element vertex'):
                num_vertex = int(tokens[-1])
            elif line.startswith('element face'):
                num_face = int(tokens[-1])
            elif line.startswith('comment'):
                if len(tokens) >= 3:
                    key, value = tokens[1], tokens[-1]
                    if key == 'H':
                        Height = int(value)
                    elif key == 'W':
                        Width = int(value)
                    elif key == 'hFov':
                        hFov = float(value)
                    elif key == 'vFov':
                        vFov = float(value)
            elif line.startswith('end_header'):
                break
        contents = ply_fi.readlines()

    if num_vertex is None:
        raise ValueError("PLY header in %s has no 'element vertex' count" % mesh_fi)
    vertex_infos = contents[:num_vertex]
    if num_face is None:
        face_infos = contents[num_vertex:]
    else:
        face_infos = contents[num_vertex:num_vertex + num_face]

    verts = []
    colors = []
    for v_info in vertex_infos:
        values = [float(v) for v in v_info.split()]
        if len(values) == 6:
            vx, vy, vz, r, g, b = values
            # No alpha stored. 1 is the code write_ply emits for layers with
            # no inpaint_id. The old code reused a stale/undefined value here.
            hi = 1.0
        else:
            vx, vy, vz, r, g, b, hi = values
        verts.append([vx, vy, vz])
        colors.append([r, g, b, hi])
    verts = np.array(verts, dtype=float).reshape(-1, 3)
    colors = np.array(colors, dtype=float).reshape(-1, 4)
    colors[..., :3] = colors[..., :3] / 255.

    faces = []
    for f_info in face_infos:
        indices = [int(f) for f in f_info.split()]
        if not indices:
            continue  # tolerate blank lines
        _, v1, v2, v3 = indices
        faces.append([v1, v2, v3])
    faces = np.array(faces, dtype=int).reshape(-1, 3)
    return verts, colors, faces, Height, Width, hFov, vFov
class Canvas_view:
    """vispy scene that renders one per-vertex-coloured, unshaded triangle mesh.

    The canvas is square, ``canvas_size * factor`` pixels per side, and uses a
    'perspective' camera. At construction the camera transform is translated
    by zero and rotated 180 degrees about the x axis.

    Args:
        fov: camera field of view, assigned to ``camera.fov``.
        verts, faces, colors: mesh data. Only ``colors[:, :3]`` is uploaded.
        canvas_size: canvas side length before scaling by ``factor``.
        factor: multiplier applied to both canvas dimensions.
        bgcolor: background colour for the SceneCanvas.
        proj: accepted for API compatibility but currently ignored; the
            camera is always 'perspective'.
    """

    def __init__(self,
                 fov,
                 verts,
                 faces,
                 colors,
                 canvas_size,
                 factor=1,
                 bgcolor='gray',
                 proj='perspective',
                 ):
        self.canvas = scene.SceneCanvas(bgcolor=bgcolor, size=(canvas_size*factor, canvas_size*factor))
        self.view = self.canvas.central_widget.add_view()
        self.view.camera = 'perspective'
        self.view.camera.fov = fov
        self.mesh = visuals.Mesh(shading=None)
        self.mesh.attach(Alpha(1.0))
        self.view.add(self.mesh)
        self.tr = self.view.camera.transform
        self.reinit_mesh(verts, faces, colors)
        self.translate([0, 0, 0])
        self.rotate(axis=[1, 0, 0], angle=180)
        self.view_changed()

    def translate(self, trans=None):
        """Apply a translation to the camera transform (default: no offset)."""
        # None sentinel instead of a shared mutable list default.
        if trans is None:
            trans = [0, 0, 0]
        self.tr.translate(trans)

    def rotate(self, axis=None, angle=0):
        """Apply a rotation of ``angle`` about ``axis`` (default x axis) to the camera transform."""
        if axis is None:
            axis = [1, 0, 0]
        self.tr.rotate(axis=axis, angle=angle)

    def view_changed(self):
        """Notify the camera that its view parameters changed."""
        self.view.camera.view_changed()

    def render(self):
        """Render the canvas and return the result of ``SceneCanvas.render``."""
        return self.canvas.render()

    def reinit_mesh(self, verts, faces, colors):
        """Replace the mesh geometry; only the RGB columns of ``colors`` are used."""
        self.mesh.set_data(vertices=verts, faces=faces, vertex_colors=colors[:, :3])

    def reinit_camera(self, fov):
        """Set a new field of view and refresh the camera."""
        self.view.camera.fov = fov
        self.view.camera.view_changed()
def _image_anchor(img_h, img_w, mesh_h, mesh_w, original_h=None, original_w=None):
    """Return ``[top, bottom, left, right]`` slice bounds of the photo inside a render.

    The render canvas built by ``output_3d_photo`` is square (side = longest
    image side), so the photo occupies a centred band: full height with a
    centred column range when the aspect ratio (H / W) is above 1, otherwise
    full width with a centred row range.

    Args:
        img_h, img_w: size of the (downsampled) rendered frame.
        mesh_h, mesh_w: mesh image size, used when the original size is unknown.
        original_h, original_w: source photo size; used for the aspect ratio
            only when both are given.

    Returns:
        A NumPy integer array ``[top, bottom, left, right]``.
    """
    if original_h is not None and original_w is not None:
        aspect_ratio = original_h / original_w
    else:
        aspect_ratio = mesh_h / mesh_w
    if aspect_ratio > 1:
        # Portrait: keep every row, crop a centred band of columns.
        img_h_len = mesh_h if original_h is None else original_h
        img_w_len = img_h_len / aspect_ratio
        anchor = [0,
                  img_h,
                  int(max(0, int(img_w // 2 - img_w_len // 2))),
                  int(min(int(img_w // 2 + img_w_len // 2), img_w - 1))]
    else:
        # Landscape or square: keep every column, crop a centred band of rows.
        img_w_len = mesh_w if original_w is None else original_w
        img_h_len = img_w_len * aspect_ratio
        anchor = [int(max(0, int(img_h // 2 - img_h_len // 2))),
                  int(min(int(img_h // 2 + img_h_len // 2), img_h - 1)),
                  0,
                  img_w]
    return np.array(anchor)


def _postprocess_frame(img, init_factor, anchor, border, crop_border):
    """Turn one raw canvas render into an output video frame.

    Steps: Gaussian blur with an odd kernel derived from ``init_factor``,
    downsample by ``init_factor`` (INTER_AREA), crop to ``anchor``, crop to
    ``border``, then, if any entry of ``crop_border`` is positive, trim that
    fraction from top/left/bottom/right and resize back to the pre-trim size.

    Args:
        img: rendered image; the crop_border step unpacks three dimensions.
        init_factor: supersampling factor the canvas was rendered at.
        anchor: ``[top, bottom, left, right]`` from ``_image_anchor``.
        border: ``[top, bottom, left, right]`` crop applied after ``anchor``.
        crop_border: four fractions ``[top, left, bottom, right]``.

    Returns:
        The processed frame (all channels kept).
    """
    ksize = int(init_factor // 2 * 2 + 1)  # odd kernel; 1x1 when init_factor is 1
    img = cv2.GaussianBlur(img, (ksize, ksize), 0)
    img = cv2.resize(img, (int(img.shape[1] / init_factor), int(img.shape[0] / init_factor)),
                     interpolation=cv2.INTER_AREA)
    img = img[anchor[0]:anchor[1], anchor[2]:anchor[3]]
    img = img[int(border[0]):int(border[1]), int(border[2]):int(border[3])]
    if any(np.array(crop_border) > 0.0):
        H_c, W_c, _ = img.shape
        o_t = int(H_c * crop_border[0])
        o_l = int(W_c * crop_border[1])
        o_b = int(H_c * crop_border[2])
        o_r = int(W_c * crop_border[3])
        img = img[o_t:H_c - o_b, o_l:W_c - o_r]
        # Scale back up so every frame keeps the same size.
        img = cv2.resize(img, (W_c, H_c), interpolation=cv2.INTER_CUBIC)
    return img


def output_3d_photo(verts, colors, faces, Height, Width, hFov, vFov, tgt_poses, video_traj_types, ref_pose,
                    output_dir, ref_image, int_mtx, config, image, videos_poses, video_basename, original_H=None, original_W=None,
                    border=None, depth=None, normal_canvas=None, all_canvas=None, mean_loc_depth=None):
    """Render one MP4 per camera trajectory from a coloured mesh.

    For every ``(video_pose, video_traj_type)`` pair, each pose in
    ``video_pose`` is turned into a camera motion relative to ``ref_pose``,
    the mesh is rendered through ``normal_canvas``, the frame is blurred,
    downsampled and cropped to the photo's footprint, and the frames are
    written to ``<output_dir>/<video_basename>_<video_traj_type>.mp4`` at
    ``config['fps']``.

    Args:
        verts, colors, faces: mesh data for ``Canvas_view``; only the first
            three colour channels are used.
        Height, Width: mesh image size. Used to scale ``int_mtx`` to pixels
            and as the fallback when ``original_H``/``original_W`` is None.
        hFov, vFov: unused; the field of view is recomputed from ``int_mtx``.
        tgt_poses, ref_image, image, depth: unused; kept for interface
            compatibility.
        video_traj_types: one label per trajectory, used in the output file
            name. A label containing ``'dolly'`` enables field-of-view
            compensation based on ``mean_loc_depth``.
        ref_pose: reference camera pose. It is inverted and indexed
            ``[:3, 3]``, so presumably 4x4 homogeneous (TODO confirm).
        output_dir: directory the MP4 files are written to.
        int_mtx: normalised intrinsics. Rows 0/1 are scaled by Width/Height
            to obtain focal lengths in pixels (assumed to be a NumPy array).
        config: reads ``'anti_flickering'`` (optional; True renders at 3x and
            downsamples), ``'crop_border'`` (four fractions
            top/left/bottom/right) and ``'fps'``.
        videos_poses: one iterable of poses per trajectory.
        video_basename: output file prefix; if a list, its first element is used.
        original_H, original_W: source photo size. When both are given, they
            set the canvas size and aspect ratio.
        border: ``[top, bottom, left, right]`` crop applied after the anchor
            crop; defaults to the whole downsampled frame.
        normal_canvas: ``Canvas_view`` to reuse; created when None.
        all_canvas: returned unchanged.
        mean_loc_depth: depth used for the dolly field-of-view compensation;
            required only for ``'dolly'`` trajectories.

    Returns:
        ``(normal_canvas, all_canvas)``.

    Raises:
        ValueError: a ``'dolly'`` trajectory is requested while
            ``mean_loc_depth`` is None.
    """
    # Field of view from the intrinsics. Fall back to the mesh size when the
    # original photo size is unknown; dividing by None used to raise TypeError.
    fov_w = Width if original_W is None else original_W
    fov_h = Height if original_H is None else original_H
    fx_pix = (int_mtx[0] * Width)[0]
    fy_pix = (int_mtx[1] * Height)[1]
    h_fov_rad = 2 * np.arctan((1. / 2.) * (fov_w / fx_pix))
    v_fov_rad = 2 * np.arctan((1. / 2.) * (fov_h / fy_pix))
    colors = colors[..., :3]
    fov_in_rad = max(v_fov_rad, h_fov_rad)
    fov = (fov_in_rad * 180 / np.pi)  # degrees, as expected by camera.fov
    print("fov: " + str(fov))

    # Supersampling factor: frames are rendered init_factor times larger and
    # downsampled to reduce flicker.
    init_factor = 3 if config.get('anti_flickering') is True else 1
    if original_H is not None and original_W is not None:
        canvas_h, canvas_w = original_H, original_W
    else:
        canvas_h, canvas_w = Height, Width
    canvas_size = max(canvas_h, canvas_w)
    if normal_canvas is None:
        normal_canvas = Canvas_view(fov,
                                    verts,
                                    faces,
                                    colors,
                                    canvas_size=canvas_size,
                                    factor=init_factor,
                                    bgcolor='gray',
                                    proj='perspective')
    else:
        normal_canvas.reinit_mesh(verts, faces, colors)
        normal_canvas.reinit_camera(fov)

    # A reference render fixes the downsampled frame size for border/anchor.
    img = normal_canvas.render()
    img = cv2.resize(img, (int(img.shape[1] / init_factor), int(img.shape[0] / init_factor)),
                     interpolation=cv2.INTER_AREA)
    if border is None:
        border = [0, img.shape[0], 0, img.shape[1]]
    anchor = _image_anchor(img.shape[0], img.shape[1], Height, Width, original_H, original_W)

    # Half-width of the view plane at mean_loc_depth; only needed for dolly.
    plane_width = None
    if mean_loc_depth is not None:
        plane_width = np.tan(fov_in_rad / 2.) * np.abs(mean_loc_depth)
    if isinstance(video_basename, list):
        video_basename = video_basename[0]
    ref_pose_inv = np.linalg.inv(ref_pose)

    # The even-size crop below uses the most recent frame's shape (this
    # carries over between trajectories, as before).
    last_frame = img
    for video_pose, video_traj_type in zip(videos_poses, video_traj_types):
        is_dolly = 'dolly' in video_traj_type
        if is_dolly and plane_width is None:
            raise ValueError("mean_loc_depth is required for 'dolly' trajectories")
        stereos = []
        for tp in video_pose:
            rel_pose = np.linalg.inv(np.dot(tp, ref_pose_inv))
            axis, angle = transforms3d.axangles.mat2axangle(rel_pose[0:3, 0:3])
            angle_deg = (angle * 180) / np.pi
            normal_canvas.rotate(axis=axis, angle=angle_deg)
            normal_canvas.translate(rel_pose[:3, 3])
            if is_dolly:
                # Widen/narrow the fov so the plane at mean_loc_depth keeps its size.
                new_mean_loc_depth = mean_loc_depth - float(rel_pose[2, 3])
                new_fov = float(np.arctan2(plane_width, np.abs(new_mean_loc_depth)) * 180. / np.pi * 2)
                normal_canvas.reinit_camera(new_fov)
            else:
                normal_canvas.reinit_camera(fov)
            normal_canvas.view_changed()
            frame = _postprocess_frame(normal_canvas.render(), init_factor, anchor, border,
                                       config['crop_border'])
            stereos.append(frame[..., :3])
            last_frame = frame
            # Undo the motion in reverse order so the canvas returns to the reference view.
            normal_canvas.translate(-rel_pose[:3, 3])
            normal_canvas.rotate(axis=axis, angle=-angle_deg)
            normal_canvas.view_changed()

        # Video encoders want even dimensions: drop a trailing odd row/column.
        frame_h, frame_w = last_frame.shape[:2]
        even_h, even_w = frame_h - frame_h % 2, frame_w - frame_w % 2
        stereos = [stereo[:even_h, :even_w, :3].astype(np.uint8) for stereo in stereos]
        clip = ImageSequenceClip(stereos, fps=config['fps'])
        clip.write_videofile(os.path.join(output_dir, video_basename + '_' + video_traj_type + '.mp4'),
                             fps=config['fps'])
    return normal_canvas, all_canvas