# Spaces: Runtime error  (web-page status banner captured during extraction; not part of the module)
from dataclasses import dataclass | |
import torch | |
import torch.nn.functional as F | |
import threestudio | |
from threestudio.models.background.base import BaseBackground | |
from threestudio.models.geometry.base import BaseImplicitGeometry | |
from threestudio.models.materials.base import BaseMaterial | |
from threestudio.models.renderers.base import VolumeRenderer | |
from threestudio.utils.typing import * | |
class PatchRenderer(VolumeRenderer):
    """Renderer wrapper that combines a low-resolution global render with a
    full-resolution random patch while the wrapped renderer is in training mode.

    In training mode the rays are bilinearly downsampled by
    ``cfg.global_downsample`` and rendered once by the base renderer; a random
    window of ``cfg.patch_size`` pixels is then rendered at full resolution and
    pasted into the upsampled global outputs. Outside training mode the call is
    forwarded to the base renderer unchanged.

    NOTE(review): no registration decorator is visible on this class in this
    view; confirm it is registered wherever ``threestudio.find`` expects it.
    """

    # ``@dataclass`` is required so the fields below become real dataclass
    # fields in addition to the ones inherited from ``VolumeRenderer.Config``.
    @dataclass
    class Config(VolumeRenderer.Config):
        # Side length (pixels) of the square patch rendered at full resolution.
        patch_size: int = 128
        # Name handed to ``threestudio.find`` to look up the wrapped renderer.
        base_renderer_type: str = ""
        # Configuration passed as first argument to the wrapped renderer.
        base_renderer: Optional[VolumeRenderer.Config] = None
        # If True, the upsampled global outputs are detached before the patch
        # is written into them.
        global_detach: bool = False
        # Integer divisor applied to H and W for the global render.
        global_downsample: int = 4

    cfg: Config

    def configure(
        self,
        geometry: BaseImplicitGeometry,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        """Instantiate the wrapped renderer named by ``cfg.base_renderer_type``.

        The renderer class is looked up with ``threestudio.find`` and built
        from ``cfg.base_renderer`` plus the given geometry, material and
        background.
        """
        renderer_cls = threestudio.find(self.cfg.base_renderer_type)
        self.base_renderer = renderer_cls(
            self.cfg.base_renderer,
            geometry=geometry,
            material=material,
            background=background,
        )

    @staticmethod
    def _resize_bhwc(x, size):
        """Bilinearly resize a channels-last ``(B, H, W, C)`` tensor to ``size``."""
        # ``F.interpolate`` works on channels-first tensors, hence the permutes.
        resized = F.interpolate(x.permute(0, 3, 1, 2), size, mode="bilinear")
        return resized.permute(0, 2, 3, 1)

    def forward(
        self,
        rays_o: Float[Tensor, "B H W 3"],
        rays_d: Float[Tensor, "B H W 3"],
        light_positions: Float[Tensor, "B 3"],
        bg_color: Optional[Tensor] = None,
        **kwargs
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Render the given rays through the wrapped renderer.

        Args:
            rays_o: Ray origins, indexed as ``(B, H, W, 3)``.
            rays_d: Ray directions, same layout as ``rays_o``.
            light_positions: Forwarded unchanged to the base renderer.
            bg_color: Forwarded unchanged to the base renderer.
            **kwargs: Forwarded unchanged to the base renderer.

        Returns:
            The base renderer's output dict. In training mode this is the
            output of the global (downsampled) render in which every image-like
            tensor has been upsampled to ``(H, W)`` and overwritten inside the
            sampled window by the full-resolution patch render.
        """
        _, H, W, _ = rays_o.shape

        # The mode switch follows the wrapped renderer, which is what
        # ``train``/``eval`` below toggle.
        if not self.base_renderer.training:
            return self.base_renderer(
                rays_o, rays_d, light_positions, bg_color, **kwargs
            )

        # --- low-resolution render of the whole image -----------------------
        downsample = self.cfg.global_downsample
        # Clamp to 1 so tiny inputs do not request an empty target size.
        global_size = (max(H // downsample, 1), max(W // downsample, 1))
        global_rays_o = self._resize_bhwc(rays_o, global_size)
        global_rays_d = self._resize_bhwc(rays_d, global_size)
        out_global = self.base_renderer(
            global_rays_o, global_rays_d, light_positions, bg_color, **kwargs
        )

        # --- full-resolution render of one random patch ---------------------
        PS = self.cfg.patch_size
        # ``torch.randint``'s upper bound is exclusive: ``+ 1`` makes the last
        # valid offset reachable, and ``max(..., 0)`` keeps the call valid when
        # the image is not larger than the patch (the slices below then clamp
        # to the image extent).
        patch_x = torch.randint(0, max(W - PS, 0) + 1, (1,)).item()
        patch_y = torch.randint(0, max(H - PS, 0) + 1, (1,)).item()
        rows = slice(patch_y, patch_y + PS)
        cols = slice(patch_x, patch_x + PS)
        out = self.base_renderer(
            rays_o[:, rows, cols],
            rays_d[:, rows, cols],
            light_positions,
            bg_color,
            **kwargs
        )

        # Only tensors laid out like ``comp_rgb`` (same rank, same leading
        # dims, any channel count) can be pasted into the global image.
        reference = out["comp_rgb"]
        valid_patch_key = [
            key
            for key, value in out.items()
            if torch.is_tensor(value)
            and value.dim() == reference.dim()
            and value[..., 0].shape == reference[..., 0].shape
        ]

        for key in valid_patch_key:
            upsampled = self._resize_bhwc(out_global[key], (H, W))
            if self.cfg.global_detach:
                upsampled = upsampled.detach()
            # Overwrite the sampled window with the full-resolution patch.
            upsampled[:, rows, cols] = out[key]
            out_global[key] = upsampled

        return out_global

    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
        """Forward the step update to the wrapped renderer."""
        self.base_renderer.update_step(epoch, global_step, on_load_weights)

    def train(self, mode=True):
        """Set the wrapped renderer's training mode and return its result.

        Note that the return value is whatever ``base_renderer.train`` returns,
        not this wrapper.
        """
        return self.base_renderer.train(mode)

    def eval(self):
        """Put the wrapped renderer in eval mode and return its result."""
        return self.base_renderer.eval()