import os
import cv2
import glob
import json
import tqdm
import numpy as np
from scipy.spatial.transform import Slerp, Rotation

import matplotlib.pyplot as plt

import trimesh

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from .utils import get_audio_features, get_rays, get_bg_coords, convert_poses, AudDataset


# ref: https://github.com/NVlabs/instant-ngp/blob/b76004c8cf478880227401ae763be4c02f80b62f/include/neural-graphics-primitives/nerf_loader.h#L50
def nerf_matrix_to_ngp(pose, scale=0.33, offset=[0, 0, 0]):
    new_pose = np.array([
        [pose[1, 0], -pose[1, 1], -pose[1, 2], pose[1, 3] * scale + offset[0]],
        [pose[2, 0], -pose[2, 1], -pose[2, 2], pose[2, 3] * scale + offset[1]],
        [pose[0, 0], -pose[0, 1], -pose[0, 2], pose[0, 3] * scale + offset[2]],
        [0, 0, 0, 1],
    ], dtype=np.float32)
    return new_pose


def smooth_camera_path(poses, kernel_size=5):
    # smooth the camera trajectory...
    # poses: [N, 4, 4], numpy array
    N = poses.shape[0]
    K = kernel_size // 2

    trans = poses[:, :3, 3].copy()  # [N, 3]
    rots = poses[:, :3, :3].copy()  # [N, 3, 3]

    for i in range(N):
        start = max(0, i - K)
        end = min(N, i + K + 1)
        poses[i, :3, 3] = trans[start:end].mean(0)
        poses[i, :3, :3] = Rotation.from_matrix(rots[start:end]).mean().as_matrix()

    return poses


def polygon_area(x, y):
    x_ = x - x.mean()
    y_ = y - y.mean()
    correction = x_[-1] * y_[0] - y_[-1] * x_[0]
    main_area = np.dot(x_[:-1], y_[1:]) - np.dot(y_[:-1], x_[1:])
    return 0.5 * np.abs(main_area + correction)


def visualize_poses(poses, size=0.1):
    # poses: [B, 4, 4]
    print(f'[INFO] visualize poses: {poses.shape}')

    axes = trimesh.creation.axis(axis_length=4)
    box = trimesh.primitives.Box(extents=(2, 2, 2)).as_outline()
    box.colors = np.array([[128, 128, 128]] * len(box.entities))
    objects = [axes, box]

    for pose in poses:
        # a camera is visualized with 8 line segments.
        pos = pose[:3, 3]
        a = pos + size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
        b = pos - size * pose[:3, 0] + size * pose[:3, 1] + size * pose[:3, 2]
        c = pos - size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]
        d = pos + size * pose[:3, 0] - size * pose[:3, 1] + size * pose[:3, 2]

        dir = (a + b + c + d) / 4 - pos
        dir = dir / (np.linalg.norm(dir) + 1e-8)
        o = pos + dir * 3

        segs = np.array([[pos, a], [pos, b], [pos, c], [pos, d], [a, b], [b, c], [c, d], [d, a], [pos, o]])
        segs = trimesh.load_path(segs)
        objects.append(segs)

    trimesh.Scene(objects).show()


from .wav2vec import *


class NeRFDataset_Test:
    def __init__(self, opt, device, downscale=1):
        super().__init__()

        self.opt = opt
        self.device = device
        self.downscale = downscale
        self.scale = opt.scale  # camera radius scale to make sure camera are inside the bounding box.
        self.offset = opt.offset  # camera offset
        self.bound = opt.bound  # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16

        self.start_index = opt.data_range[0]
        self.end_index = opt.data_range[1]

        self.training = False
        self.num_rays = -1

        # load nerf-compatible format data.
        with open(opt.pose, 'r') as f:
            transform = json.load(f)

        # load image size
        self.H = int(transform['cy']) * 2 // downscale
        self.W = int(transform['cx']) * 2 // downscale

        # read images
        frames = transform["frames"]

        # use a slice of the dataset
        if self.end_index == -1:  # abuse...
            self.end_index = len(frames)

        frames = frames[self.start_index:self.end_index]

        print(f'[INFO] load {len(frames)} frames.')

        # only load pre-calculated aud features when not live-streaming
        if not self.opt.asr:

            if self.opt.aud.endswith('npy'):
                aud_features = np.load(self.opt.aud)
            elif self.opt.aud.endswith('wav'):
                if self.opt.asr_model == 'cpierse/wav2vec2-large-xlsr-53-esperanto':
                    import argparse
                    parser = argparse.ArgumentParser()
                    parser.add_argument('--wav', type=str, default='')
                    parser.add_argument('--play', action='store_true', help="play out the audio")
                    parser.add_argument('--model', type=str, default='cpierse/wav2vec2-large-xlsr-53-esperanto')
                    # parser.add_argument('--model', type=str, default='facebook/wav2vec2-large-960h-lv60-self')
                    parser.add_argument('--save_feats', action='store_true')
                    # audio FPS
                    parser.add_argument('--fps', type=int, default=50)
                    # sliding window left-middle-right length.
                    parser.add_argument('-l', type=int, default=10)
                    parser.add_argument('-m', type=int, default=50)
                    parser.add_argument('-r', type=int, default=10)
                    opt = parser.parse_args([])  # use the defaults; the outer command line is not meant for this parser

                    # fix
                    opt.asr_wav = self.opt.aud
                    opt.asr_play = opt.play
                    opt.asr_save_feats = True
                    opt.asr_model = opt.model

                    # run the pretrained Wav2vec model to extract audio features
                    with ASR(opt) as asr:
                        asr.run()

                    # os.system(f"ls")
                    # os.system(f"python NeRF/nerf_triplane/wav2vec.py --wav {self.opt.aud} --save_feats")
                    aud_features = np.load(self.opt.aud.replace('.wav', '_eo.npy'))
                elif self.opt.asr_model == 'deepspeech':
                    os.system(f"python NeRF/data_utils/deepspeech_features/extract_ds_features.py --input {self.opt.aud}")
                    aud_features = np.load(self.opt.aud.replace('.wav', '.npy'))
                elif self.opt.asr_model == 'ave':
                    from .network import AudioEncoder
                    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
                    model = AudioEncoder().to(device).eval()
                    ckpt = torch.load('./checkpoints/audio_visual_encoder.pth')
                    model.load_state_dict({f'audio_encoder.{k}': v for k, v in ckpt.items()})
                    dataset = AudDataset(self.opt.aud)
                    data_loader = DataLoader(dataset, batch_size=64, shuffle=False)
                    outputs = []
                    for mel in data_loader:
                        mel = mel.to(device)
                        with torch.no_grad():
                            out = model(mel)
                        outputs.append(out)
                    outputs = torch.cat(outputs, dim=0).cpu()
                    first_frame, last_frame = outputs[:1], outputs[-1:]
                    aud_features = torch.cat([first_frame.repeat(2, 1), outputs, last_frame.repeat(2, 1)], dim=0).numpy()
                else:
                    try:
                        aud_features = np.load(self.opt.aud)
                    except:
                        print('[ERROR] If not using the Audio Visual Encoder, pass the path to a pre-extracted .npy feature file via --aud')
            else:
                raise NotImplementedError

            if self.opt.asr_model == 'ave':
                aud_features = torch.from_numpy(aud_features).unsqueeze(0)

                # support both [N, 16] labels and [N, 16, K] logits
                if len(aud_features.shape) == 3:
                    aud_features = aud_features.float().permute(1, 0, 2)  # [N, 16, 29] --> [N, 29, 16]

                    if self.opt.emb:
                        print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
                        aud_features = aud_features.argmax(1)  # [N, 16]
                else:
                    assert self.opt.emb, "aud only provide labels, must use --emb"
                    aud_features = aud_features.long()

                print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')
            else:
                aud_features = torch.from_numpy(aud_features)

                # support both [N, 16] labels and [N, 16, K] logits
                if len(aud_features.shape) == 3:
                    aud_features = aud_features.float().permute(0, 2, 1)  # [N, 16, 29] --> [N, 29, 16]

                    if self.opt.emb:
                        print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
                        aud_features = aud_features.argmax(1)  # [N, 16]
                else:
                    assert self.opt.emb, "aud only provide labels, must use --emb"
                    aud_features = aud_features.long()
                print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')

        self.poses = []
        self.auds = []
        self.eye_area = []

        for f in tqdm.tqdm(frames, desc='Loading data'):

            pose = np.array(f['transform_matrix'], dtype=np.float32)  # [4, 4]
            pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)
            self.poses.append(pose)

            # find the corresponding audio to the image frame
            if not self.opt.asr and self.opt.aud == '':
                aud = aud_features[min(f['aud_id'], aud_features.shape[0] - 1)]  # careful for the last frame...
                self.auds.append(aud)

            if self.opt.exp_eye:
                if 'eye_ratio' in f:
                    area = f['eye_ratio']
                else:
                    area = 0.25  # default value for opened eye
                self.eye_area.append(area)

        # load pre-extracted background image (should be the same size as training image...)
        if self.opt.bg_img == 'white':  # special
            bg_img = np.ones((self.H, self.W, 3), dtype=np.float32)
        elif self.opt.bg_img == 'black':  # special
            bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32)
        else:  # load from file
            bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED)  # [H, W, 3]
            if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W:
                bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA)
            bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
            bg_img = bg_img.astype(np.float32) / 255  # [H, W, 3/4]

        self.bg_img = bg_img

        self.poses = np.stack(self.poses, axis=0)

        # smooth camera path...
        if self.opt.smooth_path:
            self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window)

        self.poses = torch.from_numpy(self.poses)  # [N, 4, 4]

        if self.opt.asr:
            # live streaming, no pre-calculated auds
            self.auds = None
        else:
            # auds corresponding to images
            if self.opt.aud == '':
                self.auds = torch.stack(self.auds, dim=0)  # [N, 32, 16]
            # auds is novel, may have a different length with images
            else:
                self.auds = aud_features

        self.bg_img = torch.from_numpy(self.bg_img)

        if self.opt.exp_eye:
            self.eye_area = np.array(self.eye_area, dtype=np.float32)  # [N]
            print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}')

            if self.opt.smooth_eye:
                # naive 3-frame window average
                ori_eye = self.eye_area.copy()
                for i in range(ori_eye.shape[0]):
                    start = max(0, i - 1)
                    end = min(ori_eye.shape[0], i + 2)
                    self.eye_area[i] = ori_eye[start:end].mean()

            self.eye_area = torch.from_numpy(self.eye_area).view(-1, 1)  # [N, 1]

        # always preload
        self.poses = self.poses.to(self.device)

        if self.auds is not None:
            self.auds = self.auds.to(self.device)

        self.bg_img = self.bg_img.to(torch.half).to(self.device)

        if self.opt.exp_eye:
            self.eye_area = self.eye_area.to(self.device)

        # load intrinsics
        fl_x = fl_y = transform['focal_len']

        cx = (transform['cx'] / downscale)
        cy = (transform['cy'] / downscale)

        self.intrinsics = np.array([fl_x, fl_y, cx, cy])

        # directly build the coordinate meshgrid in [-1, 1]^2
        self.bg_coords = get_bg_coords(self.H, self.W, self.device)  # [1, H*W, 2] in [-1, 1]

    def mirror_index(self, index):
        size = self.poses.shape[0]
        turn = index // size
        res = index % size
        if turn % 2 == 0:
            return res
        else:
            return size - res - 1

    def collate(self, index):

        B = len(index)  # a list of length 1
        # assert B == 1

        results = {}

        # audio use the original index
        if self.auds is not None:
            auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device)
            results['auds'] = auds

        # head pose and bg image may mirror (replay --> <-- --> <--).
        index[0] = self.mirror_index(index[0])

        poses = self.poses[index].to(self.device)  # [B, 4, 4]

        rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size)

        results['index'] = index  # for ind. code
        results['H'] = self.H
        results['W'] = self.W
        results['rays_o'] = rays['rays_o']
        results['rays_d'] = rays['rays_d']

        if self.opt.exp_eye:
            results['eye'] = self.eye_area[index].to(self.device)  # [1]
        else:
            results['eye'] = None

        bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)
        results['bg_color'] = bg_img

        bg_coords = self.bg_coords  # [1, N, 2]
        results['bg_coords'] = bg_coords

        # results['poses'] = convert_poses(poses)  # [B, 6]
        # results['poses_matrix'] = poses  # [B, 4, 4]
        results['poses'] = poses  # [B, 4, 4]

        return results

    def dataloader(self):

        # test with novel auds, then use its length
        if self.auds is not None:
            size = self.auds.shape[0]
        # live stream test, use 2 * len(poses), so it naturally mirrors.
        else:
            size = 2 * self.poses.shape[0]

        loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=False, num_workers=0)
        loader._data = self  # an ugly fix... we need poses in trainer.

        # do evaluate if has gt images and use self-driven setting
        loader.has_gt = False

        return loader


class NeRFDataset:
    def __init__(self, opt, device, type='train', downscale=1):
        super().__init__()

        self.opt = opt
        self.device = device
        self.type = type  # train, val, test
        self.downscale = downscale
        self.root_path = opt.path
        self.preload = opt.preload  # 0 = disk, 1 = cpu, 2 = gpu
        self.scale = opt.scale  # camera radius scale to make sure camera are inside the bounding box.
        self.offset = opt.offset  # camera offset
        self.bound = opt.bound  # bounding box half length, also used as the radius to random sample poses.
        self.fp16 = opt.fp16

        self.start_index = opt.data_range[0]
        self.end_index = opt.data_range[1]

        self.training = self.type in ['train', 'all', 'trainval']
        self.num_rays = self.opt.num_rays if self.training else -1

        # load nerf-compatible format data.

        # load all splits (train/valid/test)
        if type == 'all':
            transform_paths = glob.glob(os.path.join(self.root_path, '*.json'))
            transform = None
            for transform_path in transform_paths:
                with open(transform_path, 'r') as f:
                    tmp_transform = json.load(f)
                    if transform is None:
                        transform = tmp_transform
                    else:
                        transform['frames'].extend(tmp_transform['frames'])
        # load train and val split
        elif type == 'trainval':
            with open(os.path.join(self.root_path, 'transforms_train.json'), 'r') as f:
                transform = json.load(f)
            with open(os.path.join(self.root_path, 'transforms_val.json'), 'r') as f:
                transform_val = json.load(f)
            transform['frames'].extend(transform_val['frames'])
        # only load one specified split
        else:
            # no test, use val as test
            _split = 'val' if type == 'test' else type
            with open(os.path.join(self.root_path, f'transforms_{_split}.json'), 'r') as f:
                transform = json.load(f)

        # load image size
        if 'h' in transform and 'w' in transform:
            self.H = int(transform['h']) // downscale
            self.W = int(transform['w']) // downscale
        else:
            self.H = int(transform['cy']) * 2 // downscale
            self.W = int(transform['cx']) * 2 // downscale

        # read images
        frames = transform["frames"]

        # use a slice of the dataset
        if self.end_index == -1:  # abuse...
            self.end_index = len(frames)

        frames = frames[self.start_index:self.end_index]

        # use a subset of dataset.
        if type == 'train':
            if self.opt.part:
                frames = frames[::10]  # 1/10 frames
            elif self.opt.part2:
                frames = frames[:375]  # first 15s
        elif type == 'val':
            frames = frames[:100]  # first 100 frames for val

        print(f'[INFO] load {len(frames)} {type} frames.')

        # only load pre-calculated aud features when not live-streaming
        if not self.opt.asr:

            # empty means the default self-driven extracted features.
            if self.opt.aud == '':
                if 'esperanto' in self.opt.asr_model:
                    aud_features = np.load(os.path.join(self.root_path, 'aud_eo.npy'))
                elif 'deepspeech' in self.opt.asr_model:
                    aud_features = np.load(os.path.join(self.root_path, 'aud_ds.npy'))
                # elif 'hubert_cn' in self.opt.asr_model:
                #     aud_features = np.load(os.path.join(self.root_path, 'aud_hu_cn.npy'))
                elif 'hubert' in self.opt.asr_model:
                    aud_features = np.load(os.path.join(self.root_path, 'aud_hu.npy'))
                else:
                    aud_features = np.load(os.path.join(self.root_path, 'aud.npy'))
            # cross-driven extracted features.
            else:
                aud_features = np.load(self.opt.aud)

            aud_features = torch.from_numpy(aud_features)

            # support both [N, 16] labels and [N, 16, K] logits
            if len(aud_features.shape) == 3:
                aud_features = aud_features.float().permute(0, 2, 1)  # [N, 16, 29] --> [N, 29, 16]

                if self.opt.emb:
                    print(f'[INFO] argmax to aud features {aud_features.shape} for --emb mode')
                    aud_features = aud_features.argmax(1)  # [N, 16]
            else:
                assert self.opt.emb, "aud only provide labels, must use --emb"
                aud_features = aud_features.long()

            print(f'[INFO] load {self.opt.aud} aud_features: {aud_features.shape}')

        # load action units
        import pandas as pd
        au_blink_info = pd.read_csv(os.path.join(self.root_path, 'au.csv'))
        au_blink = au_blink_info[' AU45_r'].values

        self.torso_img = []
        self.images = []

        self.poses = []
        self.exps = []

        self.auds = []
        self.face_rect = []
        self.lhalf_rect = []
        self.lips_rect = []
        self.eye_area = []
        self.eye_rect = []

        for f in tqdm.tqdm(frames, desc=f'Loading {type} data'):

            f_path = os.path.join(self.root_path, 'gt_imgs', str(f['img_id']) + '.jpg')

            if not os.path.exists(f_path):
                print('[WARN]', f_path, 'NOT FOUND!')
                continue

            pose = np.array(f['transform_matrix'], dtype=np.float32)  # [4, 4]
            pose = nerf_matrix_to_ngp(pose, scale=self.scale, offset=self.offset)
            self.poses.append(pose)

            if self.preload > 0:
                image = cv2.imread(f_path, cv2.IMREAD_UNCHANGED)  # [H, W, 3] or [H, W, 4]
                image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
                image = image.astype(np.float32) / 255  # [H, W, 3/4]
                self.images.append(image)
            else:
                self.images.append(f_path)

            # load frame-wise bg
            torso_img_path = os.path.join(self.root_path, 'torso_imgs', str(f['img_id']) + '.png')

            if self.preload > 0:
                torso_img = cv2.imread(torso_img_path, cv2.IMREAD_UNCHANGED)  # [H, W, 4]
                torso_img = cv2.cvtColor(torso_img, cv2.COLOR_BGRA2RGBA)
                torso_img = torso_img.astype(np.float32) / 255  # [H, W, 3/4]
                self.torso_img.append(torso_img)
            else:
                self.torso_img.append(torso_img_path)

            # find the corresponding audio to the image frame
            if not self.opt.asr and self.opt.aud == '':
                aud = aud_features[min(f['aud_id'], aud_features.shape[0] - 1)]  # careful for the last frame...
                self.auds.append(aud)

            # load lms and extract face
            lms = np.loadtxt(os.path.join(self.root_path, 'ori_imgs', str(f['img_id']) + '.lms'))  # [68, 2]

            lh_xmin, lh_xmax = int(lms[31:36, 1].min()), int(lms[:, 1].max())  # actually lower half area
            xmin, xmax = int(lms[:, 1].min()), int(lms[:, 1].max())
            ymin, ymax = int(lms[:, 0].min()), int(lms[:, 0].max())
            self.face_rect.append([xmin, xmax, ymin, ymax])
            self.lhalf_rect.append([lh_xmin, lh_xmax, ymin, ymax])

            if self.opt.exp_eye:
                # eyes_left = slice(36, 42)
                # eyes_right = slice(42, 48)

                # area_left = polygon_area(lms[eyes_left, 0], lms[eyes_left, 1])
                # area_right = polygon_area(lms[eyes_right, 0], lms[eyes_right, 1])

                # # area percentage of two eyes of the whole image...
                # area = (area_left + area_right) / (self.H * self.W) * 100

                # action units blink AU45
                area = au_blink[f['img_id']]
                area = np.clip(area, 0, 2) / 2
                # area = area + np.random.rand() / 10
                self.eye_area.append(area)

                xmin, xmax = int(lms[36:48, 1].min()), int(lms[36:48, 1].max())
                ymin, ymax = int(lms[36:48, 0].min()), int(lms[36:48, 0].max())
                self.eye_rect.append([xmin, xmax, ymin, ymax])

            if self.opt.finetune_lips:
                lips = slice(48, 60)
                xmin, xmax = int(lms[lips, 1].min()), int(lms[lips, 1].max())
                ymin, ymax = int(lms[lips, 0].min()), int(lms[lips, 0].max())

                # padding to H == W
                cx = (xmin + xmax) // 2
                cy = (ymin + ymax) // 2
                l = max(xmax - xmin, ymax - ymin) // 2
                xmin = max(0, cx - l)
                xmax = min(self.H, cx + l)
                ymin = max(0, cy - l)
                ymax = min(self.W, cy + l)

                self.lips_rect.append([xmin, xmax, ymin, ymax])

        # load pre-extracted background image (should be the same size as training image...)
        if self.opt.bg_img == 'white':  # special
            bg_img = np.ones((self.H, self.W, 3), dtype=np.float32)
        elif self.opt.bg_img == 'black':  # special
            bg_img = np.zeros((self.H, self.W, 3), dtype=np.float32)
        else:  # load from file
            # default bg
            if self.opt.bg_img == '':
                self.opt.bg_img = os.path.join(self.root_path, 'bc.jpg')
            bg_img = cv2.imread(self.opt.bg_img, cv2.IMREAD_UNCHANGED)  # [H, W, 3]
            if bg_img.shape[0] != self.H or bg_img.shape[1] != self.W:
                bg_img = cv2.resize(bg_img, (self.W, self.H), interpolation=cv2.INTER_AREA)
            bg_img = cv2.cvtColor(bg_img, cv2.COLOR_BGR2RGB)
            bg_img = bg_img.astype(np.float32) / 255  # [H, W, 3/4]

        self.bg_img = bg_img

        self.poses = np.stack(self.poses, axis=0)

        # smooth camera path...
        if self.opt.smooth_path:
            self.poses = smooth_camera_path(self.poses, self.opt.smooth_path_window)

        self.poses = torch.from_numpy(self.poses)  # [N, 4, 4]

        if self.preload > 0:
            self.images = torch.from_numpy(np.stack(self.images, axis=0))  # [N, H, W, C]
            self.torso_img = torch.from_numpy(np.stack(self.torso_img, axis=0))  # [N, H, W, C]
        else:
            self.images = np.array(self.images)
            self.torso_img = np.array(self.torso_img)

        if self.opt.asr:
            # live streaming, no pre-calculated auds
            self.auds = None
        else:
            # auds corresponding to images
            if self.opt.aud == '':
                self.auds = torch.stack(self.auds, dim=0)  # [N, 32, 16]
            # auds is novel, may have a different length with images
            else:
                self.auds = aud_features

        self.bg_img = torch.from_numpy(self.bg_img)

        if self.opt.exp_eye:
            self.eye_area = np.array(self.eye_area, dtype=np.float32)  # [N]
            print(f'[INFO] eye_area: {self.eye_area.min()} - {self.eye_area.max()}')

            if self.opt.smooth_eye:
                # naive 3-frame window average
                ori_eye = self.eye_area.copy()
                for i in range(ori_eye.shape[0]):
                    start = max(0, i - 1)
                    end = min(ori_eye.shape[0], i + 2)
                    self.eye_area[i] = ori_eye[start:end].mean()

            self.eye_area = torch.from_numpy(self.eye_area).view(-1, 1)  # [N, 1]

        # calculate mean radius of all camera poses
        self.radius = self.poses[:, :3, 3].norm(dim=-1).mean(0).item()
        # print(f'[INFO] dataset camera poses: radius = {self.radius:.4f}, bound = {self.bound}')

        # [debug] uncomment to view all training poses.
        # visualize_poses(self.poses.numpy())

        # [debug] uncomment to view examples of randomly generated poses.
        # visualize_poses(rand_poses(100, self.device, radius=self.radius).cpu().numpy())

        if self.preload > 1:
            self.poses = self.poses.to(self.device)

            if self.auds is not None:
                self.auds = self.auds.to(self.device)

            self.bg_img = self.bg_img.to(torch.half).to(self.device)

            self.torso_img = self.torso_img.to(torch.half).to(self.device)
            self.images = self.images.to(torch.half).to(self.device)

            if self.opt.exp_eye:
                self.eye_area = self.eye_area.to(self.device)

        # load intrinsics
        if 'focal_len' in transform:
            fl_x = fl_y = transform['focal_len']
        elif 'fl_x' in transform or 'fl_y' in transform:
            fl_x = (transform['fl_x'] if 'fl_x' in transform else transform['fl_y']) / downscale
            fl_y = (transform['fl_y'] if 'fl_y' in transform else transform['fl_x']) / downscale
        elif 'camera_angle_x' in transform or 'camera_angle_y' in transform:
            # blender, assert in radians. already downscaled since we use H/W
            fl_x = self.W / (2 * np.tan(transform['camera_angle_x'] / 2)) if 'camera_angle_x' in transform else None
            fl_y = self.H / (2 * np.tan(transform['camera_angle_y'] / 2)) if 'camera_angle_y' in transform else None
            if fl_x is None: fl_x = fl_y
            if fl_y is None: fl_y = fl_x
        else:
            raise RuntimeError('Failed to load focal length, please check the transforms.json!')

        cx = (transform['cx'] / downscale) if 'cx' in transform else (self.W / 2)
        cy = (transform['cy'] / downscale) if 'cy' in transform else (self.H / 2)

        self.intrinsics = np.array([fl_x, fl_y, cx, cy])

        # directly build the coordinate meshgrid in [-1, 1]^2
        self.bg_coords = get_bg_coords(self.H, self.W, self.device)  # [1, H*W, 2] in [-1, 1]

    def mirror_index(self, index):
        size = self.poses.shape[0]
        turn = index // size
        res = index % size
        if turn % 2 == 0:
            return res
        else:
            return size - res - 1

    def collate(self, index):

        B = len(index)  # a list of length 1
        # assert B == 1

        results = {}

        # audio use the original index
        if self.auds is not None:
            auds = get_audio_features(self.auds, self.opt.att, index[0]).to(self.device)
            results['auds'] = auds

        # head pose and bg image may mirror (replay --> <-- --> <--).
        index[0] = self.mirror_index(index[0])

        poses = self.poses[index].to(self.device)  # [B, 4, 4]

        if self.training and self.opt.finetune_lips:
            rect = self.lips_rect[index[0]]
            results['rect'] = rect
            rays = get_rays(poses, self.intrinsics, self.H, self.W, -1, rect=rect)
        else:
            rays = get_rays(poses, self.intrinsics, self.H, self.W, self.num_rays, self.opt.patch_size)

        results['index'] = index  # for ind. code
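        # In training mode the region masks and torch.gather calls below also rely on the
        # per-ray pixel indices ('inds') and their row/column coordinates ('i', 'j') that
        # get_rays returns alongside 'rays_o' and 'rays_d'.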
        results['H'] = self.H
        results['W'] = self.W
        results['rays_o'] = rays['rays_o']
        results['rays_d'] = rays['rays_d']

        # get a mask for rays inside rect_face
        if self.training:
            xmin, xmax, ymin, ymax = self.face_rect[index[0]]
            face_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax)  # [B, N]
            results['face_mask'] = face_mask

            xmin, xmax, ymin, ymax = self.lhalf_rect[index[0]]
            lhalf_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax)  # [B, N]
            results['lhalf_mask'] = lhalf_mask

        if self.opt.exp_eye:
            results['eye'] = self.eye_area[index].to(self.device)  # [1]
            if self.training:
                results['eye'] += (np.random.rand() - 0.5) / 10
                xmin, xmax, ymin, ymax = self.eye_rect[index[0]]
                eye_mask = (rays['j'] >= xmin) & (rays['j'] < xmax) & (rays['i'] >= ymin) & (rays['i'] < ymax)  # [B, N]
                results['eye_mask'] = eye_mask
        else:
            results['eye'] = None

        # load bg
        bg_torso_img = self.torso_img[index]
        if self.preload == 0:  # on the fly loading
            bg_torso_img = cv2.imread(bg_torso_img[0], cv2.IMREAD_UNCHANGED)  # [H, W, 4]
            bg_torso_img = cv2.cvtColor(bg_torso_img, cv2.COLOR_BGRA2RGBA)
            bg_torso_img = bg_torso_img.astype(np.float32) / 255  # [H, W, 3/4]
            bg_torso_img = torch.from_numpy(bg_torso_img).unsqueeze(0)
        bg_torso_img = bg_torso_img[..., :3] * bg_torso_img[..., 3:] + self.bg_img * (1 - bg_torso_img[..., 3:])
        bg_torso_img = bg_torso_img.view(B, -1, 3).to(self.device)

        if not self.opt.torso:
            bg_img = bg_torso_img
        else:
            bg_img = self.bg_img.view(1, -1, 3).repeat(B, 1, 1).to(self.device)

        if self.training:
            bg_img = torch.gather(bg_img, 1, torch.stack(3 * [rays['inds']], -1))  # [B, N, 3]

        results['bg_color'] = bg_img

        if self.opt.torso and self.training:
            bg_torso_img = torch.gather(bg_torso_img, 1, torch.stack(3 * [rays['inds']], -1))  # [B, N, 3]
            results['bg_torso_color'] = bg_torso_img

        images = self.images[index]  # [B, H, W, 3/4]
        if self.preload == 0:
            images = cv2.imread(images[0], cv2.IMREAD_UNCHANGED)  # [H, W, 3]
            images = cv2.cvtColor(images, cv2.COLOR_BGR2RGB)
            images = images.astype(np.float32) / 255  # [H, W, 3]
            images = torch.from_numpy(images).unsqueeze(0)
        images = images.to(self.device)

        if self.training:
            C = images.shape[-1]
            images = torch.gather(images.view(B, -1, C), 1, torch.stack(C * [rays['inds']], -1))  # [B, N, 3/4]

        results['images'] = images

        if self.training:
            bg_coords = torch.gather(self.bg_coords, 1, torch.stack(2 * [rays['inds']], -1))  # [1, N, 2]
        else:
            bg_coords = self.bg_coords  # [1, N, 2]

        results['bg_coords'] = bg_coords

        # results['poses'] = convert_poses(poses)  # [B, 6]
        # results['poses_matrix'] = poses  # [B, 4, 4]
        results['poses'] = poses  # [B, 4, 4]

        return results

    def dataloader(self):

        if self.training:
            # training len(poses) == len(auds)
            size = self.poses.shape[0]
        else:
            # test with novel auds, then use its length
            if self.auds is not None:
                size = self.auds.shape[0]
            # live stream test, use 2 * len(poses), so it naturally mirrors.
            else:
                size = 2 * self.poses.shape[0]

        loader = DataLoader(list(range(size)), batch_size=1, collate_fn=self.collate, shuffle=self.training, num_workers=0)
        loader._data = self  # an ugly fix... we need poses in trainer.

        # do evaluate if has gt images and use self-driven setting
        loader.has_gt = (self.opt.aud == '')

        return loader
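
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original training code): how a trainer would
# typically wire up NeRFDataset and consume its dataloader. The attribute names
# on `opt` mirror the fields accessed above; the values and the dataset path are
# illustrative placeholders, and a preprocessed directory (gt_imgs/, torso_imgs/,
# ori_imgs/, au.csv, bc.jpg, transforms_*.json, aud.npy) is assumed to exist.
if __name__ == '__main__':
    from types import SimpleNamespace

    opt = SimpleNamespace(
        path='data/example_subject',      # placeholder dataset directory
        preload=1,                        # 0 = disk, 1 = cpu, 2 = gpu
        scale=4, offset=[0, 0, 0], bound=1, fp16=True,
        data_range=[0, -1], num_rays=65536, patch_size=1,
        asr=False, aud='', asr_model='ave', emb=False, att=2,
        exp_eye=True, smooth_eye=False,
        smooth_path=False, smooth_path_window=7,
        bg_img='', torso=False, part=False, part2=False, finetune_lips=False,
    )
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_loader = NeRFDataset(opt, device=device, type='train').dataloader()
    for data in train_loader:
        # each batch holds the sampled rays, the matching audio feature window,
        # ground-truth pixels gathered at the sampled ray indices, and the pose.
        print(data['rays_o'].shape, data['auds'].shape, data['images'].shape, data['poses'].shape)
        break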