Spaces:
Runtime error
Runtime error
import random | |
from dataclasses import dataclass, field | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import threestudio | |
from threestudio.models.materials.base import BaseMaterial | |
from threestudio.models.networks import get_encoding, get_mlp | |
from threestudio.utils.ops import dot, get_activation | |
from threestudio.utils.typing import * | |
class NoMaterial(BaseMaterial):
    """Material that maps per-point features directly to a color output.

    Two modes, selected in ``configure``:

    * Pass-through (default): the incoming features are treated as the color
      itself; only ``cfg.color_activation`` is applied. The last feature
      dimension must equal ``cfg.n_output_dims``.
    * Network: when both ``cfg.input_feature_dims`` and
      ``cfg.mlp_network_config`` are set, an MLP built by ``get_mlp`` maps
      ``input_feature_dims`` channels to ``n_output_dims`` channels, and the
      activation is applied to the MLP output.
    """

    class Config(BaseMaterial.Config):
        # Number of channels produced by ``forward``.
        n_output_dims: int = 3
        # Activation name passed to ``get_activation`` and applied to the output.
        color_activation: str = "sigmoid"
        # Input width of the optional MLP; leaving it None disables the MLP.
        input_feature_dims: Optional[int] = None
        # Config forwarded to ``get_mlp``; leaving it None disables the MLP.
        mlp_network_config: Optional[dict] = None
        # Copied onto ``self.requires_normal`` in ``configure``.
        requires_normal: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the optional MLP and copy flags from ``cfg`` onto the module."""
        self.use_network = False
        # The MLP is only built when BOTH its input width and its config are given.
        if (
            self.cfg.input_feature_dims is not None
            and self.cfg.mlp_network_config is not None
        ):
            self.network = get_mlp(
                self.cfg.input_feature_dims,
                self.cfg.n_output_dims,
                self.cfg.mlp_network_config,
            )
            self.use_network = True
        self.requires_normal = self.cfg.requires_normal

    def forward(
        self, features: Float[Tensor, "B ... Nf"], **kwargs
    ) -> Float[Tensor, "B ... Nc"]:
        """Convert per-point features into activated colors.

        Args:
            features: Tensor whose last dimension holds the feature channels;
                all leading dimensions are preserved in the output.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            Tensor with the same leading dimensions as ``features`` and a last
            dimension of ``cfg.n_output_dims``, after ``cfg.color_activation``.

        Raises:
            AssertionError: In pass-through mode, if the feature width differs
                from ``cfg.n_output_dims``.
        """
        if not self.use_network:
            assert (
                features.shape[-1] == self.cfg.n_output_dims
            ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input."
            raw = features
        else:
            # reshape (not view): view raises on non-contiguous inputs such as
            # channel slices of a larger feature tensor; reshape is zero-copy
            # whenever view would have succeeded.
            flat = features.reshape(-1, features.shape[-1])
            raw = self.network(flat).reshape(
                *features.shape[:-1], self.cfg.n_output_dims
            )
        return get_activation(self.cfg.color_activation)(raw)

    def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]:
        """Evaluate the material for export and return its albedo.

        Runs ``forward`` (forwarding ``**kwargs``), clamps the result to
        [0, 1], and returns the first three channels under the ``"albedo"``
        key. If there are more than three channels, a warning is emitted and
        the extra channels are dropped.
        """
        color = self(features, **kwargs).clamp(0, 1)
        assert color.shape[-1] >= 3, "Output color must have at least 3 channels"
        if color.shape[-1] > 3:
            threestudio.warn(
                "Output color has >3 channels, treating the first 3 as RGB"
            )
        return {"albedo": color[..., :3]}