Spaces:
Runtime error
Runtime error
import random | |
from dataclasses import dataclass, field | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import threestudio | |
from threestudio.models.materials.base import BaseMaterial | |
from threestudio.utils.ops import dot, get_activation | |
from threestudio.utils.typing import * | |
class DiffuseWithPointLightMaterial(BaseMaterial): | |
class Config(BaseMaterial.Config): | |
ambient_light_color: Tuple[float, float, float] = (0.1, 0.1, 0.1) | |
diffuse_light_color: Tuple[float, float, float] = (0.9, 0.9, 0.9) | |
ambient_only_steps: int = 1000 | |
diffuse_prob: float = 0.75 | |
textureless_prob: float = 0.5 | |
albedo_activation: str = "sigmoid" | |
soft_shading: bool = False | |
cfg: Config | |
def configure(self) -> None: | |
self.requires_normal = True | |
self.ambient_light_color: Float[Tensor, "3"] | |
self.register_buffer( | |
"ambient_light_color", | |
torch.as_tensor(self.cfg.ambient_light_color, dtype=torch.float32), | |
) | |
self.diffuse_light_color: Float[Tensor, "3"] | |
self.register_buffer( | |
"diffuse_light_color", | |
torch.as_tensor(self.cfg.diffuse_light_color, dtype=torch.float32), | |
) | |
self.ambient_only = False | |
def forward( | |
self, | |
features: Float[Tensor, "B ... Nf"], | |
positions: Float[Tensor, "B ... 3"], | |
shading_normal: Float[Tensor, "B ... 3"], | |
light_positions: Float[Tensor, "B ... 3"], | |
ambient_ratio: Optional[float] = None, | |
shading: Optional[str] = None, | |
**kwargs, | |
) -> Float[Tensor, "B ... 3"]: | |
albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]) | |
if ambient_ratio is not None: | |
# if ambient ratio is specified, use it | |
diffuse_light_color = (1 - ambient_ratio) * torch.ones_like( | |
self.diffuse_light_color | |
) | |
ambient_light_color = ambient_ratio * torch.ones_like( | |
self.ambient_light_color | |
) | |
elif self.training and self.cfg.soft_shading: | |
# otherwise if in training and soft shading is enabled, random a ambient ratio | |
diffuse_light_color = torch.full_like( | |
self.diffuse_light_color, random.random() | |
) | |
ambient_light_color = 1.0 - diffuse_light_color | |
else: | |
# otherwise use the default fixed values | |
diffuse_light_color = self.diffuse_light_color | |
ambient_light_color = self.ambient_light_color | |
light_directions: Float[Tensor, "B ... 3"] = F.normalize( | |
light_positions - positions, dim=-1 | |
) | |
diffuse_light: Float[Tensor, "B ... 3"] = ( | |
dot(shading_normal, light_directions).clamp(min=0.0) * diffuse_light_color | |
) | |
textureless_color = diffuse_light + ambient_light_color | |
# clamp albedo to [0, 1] to compute shading | |
color = albedo.clamp(0.0, 1.0) * textureless_color | |
if shading is None: | |
if self.training: | |
# adopt the same type of augmentation for the whole batch | |
if self.ambient_only or random.random() > self.cfg.diffuse_prob: | |
shading = "albedo" | |
elif random.random() < self.cfg.textureless_prob: | |
shading = "textureless" | |
else: | |
shading = "diffuse" | |
else: | |
if self.ambient_only: | |
shading = "albedo" | |
else: | |
# return shaded color by default in evaluation | |
shading = "diffuse" | |
# multiply by 0 to prevent checking for unused parameters in DDP | |
if shading == "albedo": | |
return albedo + textureless_color * 0 | |
elif shading == "textureless": | |
return albedo * 0 + textureless_color | |
elif shading == "diffuse": | |
return color | |
else: | |
raise ValueError(f"Unknown shading type {shading}") | |
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False): | |
if global_step < self.cfg.ambient_only_steps: | |
self.ambient_only = True | |
else: | |
self.ambient_only = False | |
def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]: | |
albedo = get_activation(self.cfg.albedo_activation)(features[..., :3]).clamp( | |
0.0, 1.0 | |
) | |
return {"albedo": albedo} | |