thewhole's picture
Upload 245 files
2fa4776
raw
history blame
1.78 kB
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *
@threestudio.register("textured-background")
class TexturedBackground(BaseBackground):
    """Background that reads colours from a learnable 2D texture.

    Every direction vector is turned into a pair of normalised spherical
    angles, which are then used to bilinearly sample a
    ``(1, n_output_dims, height, width)`` parameter tensor.
    """

    @dataclass
    class Config(BaseBackground.Config):
        # Number of channels stored per texel and returned per direction.
        n_output_dims: int = 3
        # Texture resolution: rows (height) and columns (width).
        height: int = 64
        width: int = 64
        # Name handed to ``get_activation`` to select the output non-linearity.
        color_activation: str = "sigmoid"

    cfg: Config

    def configure(self) -> None:
        """Allocate the texture parameter, drawn from a standard normal."""
        texture_shape = (
            1,
            self.cfg.n_output_dims,
            self.cfg.height,
            self.cfg.width,
        )
        self.texture = nn.Parameter(torch.randn(texture_shape))

    def spherical_xyz_to_uv(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B 2"]:
        """Convert direction vectors to ``(u, v)`` coordinates in ``[0, 1]``.

        ``u`` is the angle measured from the +z axis divided by pi, and ``v``
        is the angle around the z axis (``atan2(y, x)``) remapped from
        ``[-pi, pi]`` to ``[0, 1]``.

        Args:
            dirs: Tensor whose last axis holds the x, y, z components.

        Returns:
            Tensor with the same leading shape and a last axis of ``(u, v)``.
        """
        x = dirs[..., 0]
        y = dirs[..., 1]
        z = dirs[..., 2]

        # Length of the projection onto the xy-plane; it is non-negative, so
        # atan2(planar_norm, z) lies in [0, pi].
        planar_norm = (x**2 + y**2) ** 0.5
        polar = torch.atan2(planar_norm, z) / torch.pi
        azimuth = torch.atan2(y, x) / (torch.pi * 2) + 0.5

        return torch.stack([polar, azimuth], -1)

    def forward(self, dirs: Float[Tensor, "*B 3"]) -> Float[Tensor, "*B Nc"]:
        """Sample the texture along ``dirs`` and apply the colour activation.

        Args:
            dirs: Direction vectors with arbitrary leading batch dimensions.

        Returns:
            Activated colours with the same leading dimensions as ``dirs`` and
            ``cfg.n_output_dims`` channels on the last axis.
        """
        batch_shape = dirs.shape[:-1]
        n_channels = self.cfg.n_output_dims

        # Flatten all batch axes so the directions form a single list.
        flat_dirs = dirs.reshape(-1, dirs.shape[-1])
        grid = self.spherical_xyz_to_uv(flat_dirs)
        grid = 2 * grid - 1  # rescale to [-1, 1] for grid_sample
        # grid_sample wants (N, H_out, W_out, 2); treat the directions as a
        # single column of sample locations.
        grid = grid.reshape(1, -1, 1, 2)

        sampled = F.grid_sample(
            self.texture,
            grid,
            mode="bilinear",
            padding_mode="reflection",
            align_corners=False,
        )

        # (1, C, N, 1) -> (C, N) -> (N, C) -> (*B, C)
        channels_first = sampled.reshape(n_channels, -1)
        per_direction = channels_first.transpose(0, 1)
        color = per_direction.reshape(*batch_shape, n_channels)

        activation = get_activation(self.cfg.color_activation)
        return activation(color)