Spaces:
Runtime error
Runtime error
from dataclasses import dataclass, field | |
from typing import Any, List, Optional | |
import open_clip | |
import torch | |
import torch.nn as nn | |
from jaxtyping import Float | |
from torch import Tensor | |
from torchvision.transforms import Normalize | |
from sf3d.models.network import get_activation | |
from sf3d.models.utils import BaseModule | |
@dataclass
class HeadSpec:
    """Declarative description of one prediction head of ``ClipBasedHeadEstimator``.

    The ``@dataclass`` decorator is required: it generates the ``__init__`` that
    lets a spec be built from keyword arguments and makes the defaults below
    per-instance values.
    """

    # Key of the head in the estimator's ``heads`` ModuleDict and in the
    # dictionary returned by ``forward``.
    name: str
    # Not read by the estimator code in this file (each branch there ends in a
    # single-output linear layer); presumably consumed elsewhere — TODO confirm.
    out_channels: int
    # Number of (Linear, activation) pairs in the head's shared trunk.
    n_hidden_layers: int
    # Name handed to ``get_activation``; the result is applied to the
    # evaluated (mean / mode / sampled) output of the head.
    output_activation: Optional[str] = None
    # Constant added to the raw network outputs before they parameterize the
    # distribution.
    output_bias: float = 0.0
    # When true, the head's output is stored under ``decoder_<name>`` instead
    # of ``<name>``.
    add_to_decoder_features: bool = False
    # Optional target shape the evaluated output is reshaped to.
    shape: Optional[list[int]] = None
class ClipBasedHeadEstimator(BaseModule):
    """Estimate per-head probability distributions from frozen CLIP image features.

    A CLIP image encoder (loaded through ``open_clip``) embeds the conditioning
    image. Every configured head then runs a small MLP on that embedding and
    emits two scalars per image, which parameterize either a Beta or a Normal
    distribution. ``forward`` returns the distributions themselves or a value
    evaluated from them (mean, mode or samples).
    """

    @dataclass
    class Config(BaseModule.Config):
        # Architecture name and pretrained-weights tag passed to
        # ``open_clip.create_model_and_transforms``.
        model: str = "ViT-B-32"
        pretrain: str = "laion2b_s34b_b79k"

        # Distribution family predicted by every head: "beta" or "normal".
        distribution: str = "beta"
        # How a value is drawn from the distribution when ``sample=True``:
        # ["mean", "mode", "sample", "sample_mean"]; any other value behaves
        # like "sample".
        distribution_eval: str = "mode"

        # Activation used inside the head MLPs: "relu" or "silu".
        activation: str = "relu"
        # Width of every hidden linear layer. The first layer consumes the
        # CLIP image embedding directly, so this presumably has to match the
        # embedding width of ``model`` — verify when changing ``model``.
        hidden_features: int = 512
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self):
        """Load the frozen CLIP backbone and build one MLP stack per head.

        Raises:
            ValueError: If the configuration does not declare any head.
        """
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.cfg.model, pretrained=self.cfg.pretrain
        )
        self.model.eval()

        # Do not add the weights in self.model to the optimizer
        for param in self.model.parameters():
            param.requires_grad = False

        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(self.cfg.heads) == 0:
            raise ValueError("ClipBasedHeadEstimator requires at least one head")

        width = self.cfg.hidden_features
        heads = {}
        for head in self.cfg.heads:
            # Shared trunk: ``n_hidden_layers`` (Linear, activation) pairs.
            trunk_layers = []
            for _ in range(head.n_hidden_layers):
                trunk_layers += [
                    nn.Linear(width, width),
                    self.make_activation(self.cfg.activation),
                ]

            # Two identical branches, one per distribution parameter, each
            # ending in a single-output linear layer.
            branches = [
                nn.Sequential(
                    nn.Linear(width, width),
                    self.make_activation(self.cfg.activation),
                    nn.Linear(width, 1),
                )
                for _ in range(2)
            ]

            # Layout relied on by ``forward``: [shared trunk, branch 1, branch 2].
            heads[head.name] = nn.ModuleList(
                [nn.Sequential(*trunk_layers), *branches]
            )
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Return a new in-place activation module for the given name.

        Args:
            activation: Either ``"relu"`` or ``"silu"``.

        Raises:
            NotImplementedError: For any other activation name.
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        if activation == "silu":
            return nn.SiLU(inplace=True)
        raise NotImplementedError(f"Unsupported activation: {activation!r}")

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 3"],
        sample: bool = True,
    ) -> dict[str, Any]:
        """Predict the distribution (and optionally a value) for every head.

        Args:
            cond_image: Channels-last conditioning image; its first two
                dimensions are merged before it is fed to CLIP.
            sample: When true, each head's entry is a tensor evaluated from
                the distribution according to ``cfg.distribution_eval`` and
                passed through the head's output activation; the distribution
                itself is stored under ``"<name>_dist"``. When false, each
                head's entry is the distribution object.

        Returns:
            Dictionary keyed by head name, or by ``"decoder_<name>"`` for
            heads with ``add_to_decoder_features`` set.

        Raises:
            NotImplementedError: If ``cfg.distribution`` is neither
                ``"normal"`` nor ``"beta"``.
            ValueError: If a head requests a reshape while ``sample`` is false.
        """
        # Run the model
        # Resize cond_image to 224
        cond_image = nn.functional.interpolate(
            cond_image.flatten(0, 1).permute(0, 3, 1, 2),
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
        cond_image = Normalize(
            mean=open_clip.constants.OPENAI_DATASET_MEAN,
            std=open_clip.constants.OPENAI_DATASET_STD,
        )(cond_image)
        image_features = self.model.encode_image(cond_image)

        # Run the heads
        outputs = {}
        for head_dict in self.cfg.heads:
            head_name = head_dict.name
            shared_head, d1_h, d2_h = self.heads[head_name]
            shared_features = shared_head(image_features)
            # Each branch ends in a single-output layer; drop that unit dim.
            d1, d2 = [head(shared_features).squeeze(-1) for head in [d1_h, d2_h]]

            if self.cfg.distribution == "normal":
                mean = d1
                var = d2
                # NOTE(review): after the squeeze above the last dimension is
                # the flattened image dimension, so this branches on the
                # number of images — confirm that this is intended.
                if mean.shape[-1] == 1:
                    outputs[head_name] = torch.distributions.Normal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var),
                    )
                else:
                    outputs[head_name] = torch.distributions.MultivariateNormal(
                        mean + head_dict.output_bias,
                        torch.nn.functional.softplus(var).diag_embed(),
                    )
            elif self.cfg.distribution == "beta":
                # Softplus keeps both concentration parameters positive.
                outputs[head_name] = torch.distributions.Beta(
                    torch.nn.functional.softplus(d1 + head_dict.output_bias),
                    torch.nn.functional.softplus(d2 + head_dict.output_bias),
                )
            else:
                raise NotImplementedError(
                    f"Unsupported distribution: {self.cfg.distribution!r}"
                )

        if sample:
            for head_dict in self.cfg.heads:
                head_name = head_dict.name
                dist = outputs[head_name]

                if self.cfg.distribution_eval == "mean":
                    out = dist.mean
                elif self.cfg.distribution_eval == "mode":
                    out = dist.mode
                elif self.cfg.distribution_eval == "sample_mean":
                    # ``sample([10])`` prepends the sample dimension, so the
                    # draws are averaged over dim 0 (not the last dimension,
                    # which indexes the images).
                    out = dist.sample([10]).mean(0)
                else:
                    # use rsample if gradient is needed
                    out = dist.rsample() if self.training else dist.sample()

                outputs[head_name] = get_activation(head_dict.output_activation)(out)
                outputs[f"{head_name}_dist"] = dist

        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilisitic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                # Move the entry to the ``decoder_`` prefixed key.
                outputs[f"decoder_{head.name}"] = outputs[head.name]
                del outputs[head.name]

        return outputs