|
from dataclasses import dataclass, field |
|
from typing import Any, List, Optional |
|
|
|
import open_clip |
|
import torch |
|
import torch.nn as nn |
|
from jaxtyping import Float |
|
from torch import Tensor |
|
from torchvision.transforms import Normalize |
|
|
|
from sf3d.models.network import get_activation |
|
from sf3d.models.utils import BaseModule |
|
|
|
|
|
@dataclass
class HeadSpec:
    """Configuration of a single prediction head of ClipBasedHeadEstimator.

    Fields, in constructor (positional) order:
        name: Key for the head's modules and for its entries in the output dict.
        out_channels: Declared channel count of the head.
            NOTE(review): ClipBasedHeadEstimator in this file does not read it;
            its heads always emit one value per distribution parameter.
        n_hidden_layers: Number of hidden Linear+activation layers in the
            head's shared trunk.
        output_activation: Activation name passed to ``get_activation`` and
            applied to the point estimate (default ``None``).
        output_bias: Offset added to the raw predicted parameters before the
            distribution is built (default ``0.0``).
        add_to_decoder_features: When True, the head's output is stored under
            ``"decoder_<name>"`` instead of ``"<name>"`` (default ``False``).
        shape: If set, the sampled output is reshaped to this shape
            (default ``None``).
    """

    # Required fields -------------------------------------------------------
    name: str
    out_channels: int
    n_hidden_layers: int

    # Optional fields -------------------------------------------------------
    output_activation: Optional[str] = field(default=None)
    output_bias: float = field(default=0.0)
    add_to_decoder_features: bool = field(default=False)
    shape: Optional[list[int]] = field(default=None)
|
|
|
|
|
class ClipBasedHeadEstimator(BaseModule):
    """Predict per-image properties from frozen open_clip image features.

    The conditioning image is embedded by a frozen, eval-mode open_clip
    model. For every configured :class:`HeadSpec`, a small MLP maps the
    embedding to two scalars that parameterize a probability distribution
    (``"beta"`` or ``"normal"``). When ``forward(..., sample=True)`` each
    distribution is reduced to a tensor according to ``cfg.distribution_eval``.
    """

    @dataclass
    class Config(BaseModule.Config):
        # open_clip architecture name and pretrained-weights tag, passed to
        # ``open_clip.create_model_and_transforms``.
        model: str = "ViT-B-32"
        pretrain: str = "laion2b_s34b_b79k"

        # Output distribution family: "beta" or "normal".
        distribution: str = "beta"

        # How a distribution becomes a tensor when sampling is requested:
        # "mean", "mode", "sample_mean" (average of 10 draws), or any other
        # value to draw a single sample.
        distribution_eval: str = "mode"

        # Activation used inside the head MLPs: "relu" or "silu".
        activation: str = "relu"
        # Width of every head layer. The first Linear of each head consumes
        # the CLIP image features, so this must equal their width
        # (TODO confirm for non-default CLIP models).
        hidden_features: int = 512
        # One spec per predicted property; must be non-empty.
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self):
        """Load the frozen CLIP encoder and build one MLP per configured head.

        Raises:
            ValueError: if ``cfg.heads`` is empty.
        """
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.cfg.model, pretrained=self.cfg.pretrain
        )
        # CLIP is used purely as a frozen feature extractor.
        self.model.eval()
        for param in self.model.parameters():
            param.requires_grad = False

        if not self.cfg.heads:
            raise ValueError("ClipBasedHeadEstimator requires at least one head")

        # Heads are built in config order, which also fixes parameter-init order.
        self.heads = nn.ModuleDict(
            {head.name: self._build_head(head.n_hidden_layers) for head in self.cfg.heads}
        )

    def _build_head(self, n_hidden_layers: int) -> nn.ModuleList:
        """Build ``[shared trunk, parameter-1 MLP, parameter-2 MLP]`` for one head.

        This layout determines the state-dict keys, so keep it stable if
        existing checkpoints must still load.
        """
        width = self.cfg.hidden_features

        trunk_layers = []
        for _ in range(n_hidden_layers):
            trunk_layers += [
                nn.Linear(width, width),
                self.make_activation(self.cfg.activation),
            ]

        # Two independent MLPs, one per distribution parameter, each emitting
        # a single value.
        param_heads = [
            nn.Sequential(
                nn.Linear(width, width),
                self.make_activation(self.cfg.activation),
                nn.Linear(width, 1),
            )
            for _ in range(2)
        ]
        return nn.ModuleList([nn.Sequential(*trunk_layers), *param_heads])

    def make_activation(self, activation):
        """Return a new in-place activation module for ``activation``.

        Raises:
            NotImplementedError: if ``activation`` is not "relu" or "silu".
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        if activation == "silu":
            return nn.SiLU(inplace=True)
        raise NotImplementedError(f"Unsupported activation: {activation!r}")

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 3"],
        sample: bool = True,
    ) -> dict[str, Any]:
        """Predict a distribution (and optionally a value) for every head.

        Args:
            cond_image: Channels-last images annotated as ``(B, 1, H, W, 3)``.
                The first two dimensions are flattened together.
            sample: If True, reduce each distribution to a tensor according to
                ``cfg.distribution_eval`` and also return the distribution
                under ``"<name>_dist"``. If False, ``"<name>"`` holds the
                ``torch.distributions`` object itself.

        Returns:
            Dict keyed by head name. Heads with ``add_to_decoder_features``
            are stored under ``"decoder_<name>"`` instead of ``"<name>"``.

        Raises:
            ValueError: if a head sets ``shape`` while ``sample`` is False.
            NotImplementedError: for an unknown ``cfg.distribution``.
        """
        image_features = self._encode_image(cond_image)

        outputs = {}
        for head in self.cfg.heads:
            outputs[head.name] = self._predict_distribution(head, image_features)

        if sample:
            for head in self.cfg.heads:
                dist = outputs[head.name]
                value = self._point_estimate(dist)
                outputs[head.name] = get_activation(head.output_activation)(value)
                outputs[f"{head.name}_dist"] = dist

        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilistic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                outputs[f"decoder_{head.name}"] = outputs.pop(head.name)

        return outputs

    def _encode_image(self, cond_image: Tensor) -> Tensor:
        """Resize, normalize and CLIP-encode the conditioning images."""
        # Merge the first two dims and move channels first: (B*N, 3, H, W).
        images = cond_image.flatten(0, 1).permute(0, 3, 1, 2)
        images = nn.functional.interpolate(
            images,
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
        # NOTE(review): presumably expects RGB values in [0, 1] before this
        # normalization -- confirm against callers.
        images = Normalize(
            mean=open_clip.constants.OPENAI_DATASET_MEAN,
            std=open_clip.constants.OPENAI_DATASET_STD,
        )(images)
        return self.model.encode_image(images)

    def _predict_distribution(
        self, head: HeadSpec, image_features: Tensor
    ) -> torch.distributions.Distribution:
        """Run one head's MLPs and wrap the two outputs in a distribution.

        Raises:
            NotImplementedError: for an unknown ``cfg.distribution``.
        """
        shared_head, d1_head, d2_head = self.heads[head.name]
        shared_features = shared_head(image_features)
        # Each parameter MLP ends in Linear(..., 1); dropping that trailing
        # singleton leaves one scalar parameter per batch item.
        d1 = d1_head(shared_features).squeeze(-1)
        d2 = d2_head(shared_features).squeeze(-1)

        if self.cfg.distribution == "normal":
            # One independent Normal per batch item (softplus keeps the scale
            # positive). The former `mean.shape[-1] == 1` test actually checked
            # batch size and switched to a MultivariateNormal over the batch
            # (with variance rather than std semantics) whenever B > 1.
            return torch.distributions.Normal(
                d1 + head.output_bias,
                torch.nn.functional.softplus(d2),
            )
        if self.cfg.distribution == "beta":
            # Beta concentrations must be strictly positive.
            return torch.distributions.Beta(
                torch.nn.functional.softplus(d1 + head.output_bias),
                torch.nn.functional.softplus(d2 + head.output_bias),
            )
        raise NotImplementedError(
            f"Unsupported distribution: {self.cfg.distribution!r}"
        )

    def _point_estimate(self, dist: torch.distributions.Distribution) -> Tensor:
        """Reduce ``dist`` to a tensor according to ``cfg.distribution_eval``."""
        mode = self.cfg.distribution_eval
        if mode == "mean":
            return dist.mean
        if mode == "mode":
            return dist.mode
        if mode == "sample_mean":
            # sample([10]) prepends a sample dim, so average over dim 0; dim -1
            # would average across the batch instead.
            return dist.sample([10]).mean(0)
        # Any other value: draw one sample. Use the reparameterized version in
        # training mode so gradients can flow through the draw.
        return dist.rsample() if self.training else dist.sample()
|
|