#### modeling.py
import math

import numpy as np
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig

from .dino_wrapper2 import DinoWrapper
from .transformer import TriplaneTransformer
from .synthesizer_part import TriplaneSynthesizer

class CameraEmbedder(nn.Module):
    """Embeds raw camera parameters with a two-layer MLP."""

    def __init__(self, raw_dim: int, embed_dim: int):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(raw_dim, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

    def forward(self, x):
        return self.mlp(x)

class LRMGeneratorConfig(PretrainedConfig):
    model_type = "lrm_generator"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.camera_embed_dim = kwargs.get("camera_embed_dim", 1024)
        self.rendering_samples_per_ray = kwargs.get("rendering_samples_per_ray", 128)
        self.transformer_dim = kwargs.get("transformer_dim", 1024)
        self.transformer_layers = kwargs.get("transformer_layers", 16)
        self.transformer_heads = kwargs.get("transformer_heads", 16)
        self.triplane_low_res = kwargs.get("triplane_low_res", 32)
        self.triplane_high_res = kwargs.get("triplane_high_res", 64)
        self.triplane_dim = kwargs.get("triplane_dim", 80)
        self.encoder_freeze = kwargs.get("encoder_freeze", False)
        self.encoder_model_name = kwargs.get("encoder_model_name", "facebook/dinov2-base")
        self.encoder_feat_dim = kwargs.get("encoder_feat_dim", 768)

class LRMGenerator(PreTrainedModel):
    """LRM generator: DINO image encoder + camera embedder + triplane transformer + triplane synthesizer."""

    config_class = LRMGeneratorConfig

    def __init__(self, config: LRMGeneratorConfig):
        super().__init__(config)
        self.encoder_feat_dim = config.encoder_feat_dim
        self.camera_embed_dim = config.camera_embed_dim
        # Image encoder (DINOv2 features).
        self.encoder = DinoWrapper(
            model_name=config.encoder_model_name,
            freeze=config.encoder_freeze,
        )
        # Camera conditioning: 12 flattened extrinsics values + 4 intrinsics values.
        self.camera_embedder = CameraEmbedder(
            raw_dim=12 + 4, embed_dim=config.camera_embed_dim,
        )
        # Transformer that maps image tokens + camera embedding to a triplane representation.
        self.transformer = TriplaneTransformer(
            inner_dim=config.transformer_dim,
            num_layers=config.transformer_layers,
            num_heads=config.transformer_heads,
            image_feat_dim=config.encoder_feat_dim,
            camera_embed_dim=config.camera_embed_dim,
            triplane_low_res=config.triplane_low_res,
            triplane_high_res=config.triplane_high_res,
            triplane_dim=config.triplane_dim,
        )
        # Volume renderer over the predicted triplanes.
        self.synthesizer = TriplaneSynthesizer(
            triplane_dim=config.triplane_dim,
            samples_per_ray=config.rendering_samples_per_ray,
        )

    def forward(self, image, camera, export_mesh=False, mesh_size=512, render_size=384, export_video=False, fps=30):
        assert image.shape[0] == camera.shape[0], "Batch size mismatch"
        N = image.shape[0]

        # Encode the input image into patch features.
        image_feats = self.encoder(image)
        assert image_feats.shape[-1] == self.encoder_feat_dim, \
            f"Feature dimension mismatch: {image_feats.shape[-1]} vs {self.encoder_feat_dim}"

        # Embed the source camera parameters.
        camera_embeddings = self.camera_embedder(camera)
        assert camera_embeddings.shape[-1] == self.camera_embed_dim, \
            f"Feature dimension mismatch: {camera_embeddings.shape[-1]} vs {self.camera_embed_dim}"

        with torch.no_grad():
            # Transformer predicts the triplane representation.
            planes = self.transformer(image_feats, camera_embeddings)
            assert planes.shape[0] == N, "Batch size mismatch for planes"
            assert planes.shape[1] == 3, "Planes should have 3 channels"

            # Generate a textured mesh via marching cubes on the density grid.
            if export_mesh:
                import mcubes
                import trimesh
                grid_out = self.synthesizer.forward_grid(planes=planes, grid_size=mesh_size)
                vtx, faces = mcubes.marching_cubes(grid_out['sigma'].float().squeeze(0).squeeze(-1).cpu().numpy(), 1.0)
                # Map voxel coordinates to the [-1, 1] cube.
                vtx = vtx / (mesh_size - 1) * 2 - 1
                vtx_tensor = torch.tensor(vtx, dtype=torch.float32, device=image.device).unsqueeze(0)
                # Query per-vertex colors from the triplanes.
                vtx_colors = self.synthesizer.forward_points(planes, vtx_tensor)['rgb'].float().squeeze(0).cpu().numpy()
                vtx_colors = (vtx_colors * 255).astype(np.uint8)
                mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
                mesh_path = "awesome_mesh.obj"
                mesh.export(mesh_path, 'obj')
                return planes, mesh_path

            # Generate a turntable video from a ring of default render cameras.
            if export_video:
                import imageio
                render_cameras = self._default_render_cameras(batch_size=N).to(image.device)
                frames = []
                chunk_size = 1  # Render one view at a time to bound memory; adjust as needed.
                for i in range(0, render_cameras.shape[1], chunk_size):
                    frame_chunk = self.synthesizer(
                        planes,
                        render_cameras[:, i:i + chunk_size],
                        render_size,
                        render_size,
                        0,
                        0,
                    )
                    frames.append(frame_chunk['images_rgb'])
                frames = torch.cat(frames, dim=1)
                frames = (frames.permute(0, 2, 3, 1).cpu().numpy() * 255).astype(np.uint8)
                # Save video
                video_path = "awesome_video.mp4"
                imageio.mimwrite(video_path, frames, fps=fps)
                return planes, video_path

        return planes

    # Copied from https://github.com/facebookresearch/vfusion3d/blob/main/lrm/cam_utils.py
    # and https://github.com/facebookresearch/vfusion3d/blob/main/lrm/inferrer.py
    def _default_intrinsics(self):
        # Packed intrinsics layout: [[fx, fy], [cx, cy], [width, height]].
        fx = fy = 384
        cx = cy = 256
        w = h = 512
        intrinsics = torch.tensor([
            [fx, fy],
            [cx, cy],
            [w, h],
        ], dtype=torch.float32)
        return intrinsics

    def _default_render_cameras(self, batch_size=1):
        # Place M cameras on a circle around the object, with a small random azimuth offset.
        M = 160  # Number of views
        radius = 1.5
        elevation = 0
        camera_positions = []
        rand_theta = np.random.uniform(0, np.pi / 180)  # up to 1 degree of jitter
        elevation = math.radians(elevation)
        for i in range(M):
            theta = 2 * math.pi * i / M + rand_theta
            x = radius * math.cos(theta) * math.cos(elevation)
            y = radius * math.sin(theta) * math.cos(elevation)
            z = radius * math.sin(elevation)
            camera_positions.append([x, y, z])
        camera_positions = torch.tensor(camera_positions, dtype=torch.float32)
        extrinsics = self.center_looking_at_camera_pose(camera_positions)
        intrinsics = self._default_intrinsics().unsqueeze(0).repeat(extrinsics.shape[0], 1, 1)
        render_cameras = self.build_camera_standard(extrinsics, intrinsics)
        return render_cameras.unsqueeze(0).repeat(batch_size, 1, 1)

    def center_looking_at_camera_pose(self, camera_position: torch.Tensor, look_at: torch.Tensor = None, up_world: torch.Tensor = None):
        # Build 3x4 camera-to-world extrinsics whose z-axis points from the look-at target toward the camera.
        if look_at is None:
            look_at = torch.tensor([0, 0, 0], dtype=torch.float32)
        if up_world is None:
            up_world = torch.tensor([0, 0, 1], dtype=torch.float32)
        look_at = look_at.unsqueeze(0).repeat(camera_position.shape[0], 1)
        up_world = up_world.unsqueeze(0).repeat(camera_position.shape[0], 1)
        z_axis = camera_position - look_at
        z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
        x_axis = torch.cross(up_world, z_axis, dim=-1)
        x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
        y_axis = torch.cross(z_axis, x_axis, dim=-1)
        y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
        extrinsics = torch.stack([x_axis, y_axis, z_axis, camera_position], dim=-1)
        return extrinsics

    def get_normalized_camera_intrinsics(self, intrinsics: torch.Tensor):
        # Unpack the (3, 2) intrinsics layout and normalize by image width/height.
        fx, fy = intrinsics[:, 0, 0], intrinsics[:, 0, 1]
        cx, cy = intrinsics[:, 1, 0], intrinsics[:, 1, 1]
        width, height = intrinsics[:, 2, 0], intrinsics[:, 2, 1]
        fx, fy = fx / width, fy / height
        cx, cy = cx / width, cy / height
        return fx, fy, cx, cy

    def build_camera_standard(self, RT: torch.Tensor, intrinsics: torch.Tensor):
        # Flatten 4x4 extrinsics and 3x3 normalized intrinsics into a single (N, 16 + 9) camera vector.
        E = self.compose_extrinsic_RT(RT)
        fx, fy, cx, cy = self.get_normalized_camera_intrinsics(intrinsics)
        I = torch.stack([
            torch.stack([fx, torch.zeros_like(fx), cx], dim=-1),
            torch.stack([torch.zeros_like(fy), fy, cy], dim=-1),
            torch.tensor([[0, 0, 1]], dtype=torch.float32, device=RT.device).repeat(RT.shape[0], 1),
        ], dim=1)
        return torch.cat([
            E.reshape(-1, 16),
            I.reshape(-1, 9),
        ], dim=-1)

    def compose_extrinsic_RT(self, RT: torch.Tensor):
        # Append the homogeneous row [0, 0, 0, 1] to turn 3x4 RT into 4x4 extrinsics.
        return torch.cat([
            RT,
            torch.tensor([[[0, 0, 0, 1]]], dtype=torch.float32).repeat(RT.shape[0], 1, 1).to(RT.device),
        ], dim=1)
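

# --- Minimal usage sketch (not part of the original file) ---------------------
# Illustrates the expected call signature only. The camera vector must be 16-dim
# (12 flattened extrinsics values + 4 intrinsics values, per CameraEmbedder's
# raw_dim); the image resolution below is an assumption and depends on what
# DinoWrapper expects. The model here is randomly initialized; real weights
# would be loaded from a published checkpoint via LRMGenerator.from_pretrained.
if __name__ == "__main__":
    config = LRMGeneratorConfig()
    model = LRMGenerator(config).eval()   # instantiating downloads the DINOv2 encoder

    image = torch.randn(1, 3, 448, 448)   # placeholder input image (shape is an assumption)
    camera = torch.randn(1, 16)           # placeholder source-camera vector

    planes = model(image, camera)
    print("Predicted triplanes:", planes.shape)  # batch of 3 feature planes (see asserts in forward)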