stable-fast-3d / sf3d /models /image_estimator /clip_based_estimator.py
mboss's picture
Initial commit
d945eeb
raw
history blame
5.75 kB
from dataclasses import dataclass, field
from typing import Any, List, Optional
import open_clip
import torch
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor
from torchvision.transforms import Normalize
from sf3d.models.network import get_activation
from sf3d.models.utils import BaseModule
@dataclass
class HeadSpec:
    """Declarative description of one prediction head.

    Field order is part of the constructor signature and must not change.
    """

    # Key under which the head's modules and outputs are stored.
    name: str
    # NOTE(review): not read anywhere in this file; presumably consumed by
    # other modules or kept for config compatibility -- confirm.
    out_channels: int
    # Number of (Linear, activation) pairs in the head's shared trunk.
    n_hidden_layers: int
    # Name handed to ``get_activation`` and applied to the evaluated output.
    output_activation: Optional[str] = None
    # Constant offset added to the raw head predictions.
    output_bias: float = 0.0
    # When true, the output is re-keyed as ``decoder_<name>``.
    add_to_decoder_features: bool = False
    # Optional target shape the evaluated output is reshaped to.
    shape: Optional[list[int]] = None
class ClipBasedHeadEstimator(BaseModule):
    """Estimate per-head probability distributions from an open_clip embedding.

    A pretrained open_clip model (frozen: ``eval()`` mode and
    ``requires_grad=False`` on every parameter) encodes the conditioning
    image. Each configured head then runs a shared MLP trunk followed by two
    small MLP branches that each emit a single value per sample; those two
    values parameterise either a Beta or a Normal distribution.
    """

    @dataclass
    class Config(BaseModule.Config):
        # open_clip architecture name and pretrained-weights tag.
        model: str = "ViT-B-32"
        pretrain: str = "laion2b_s34b_b79k"

        # Distribution family built from the two branch outputs:
        # "beta" or "normal".
        distribution: str = "beta"

        # ["mean", "mode", "sample", "sample_mean"]
        distribution_eval: str = "mode"

        # Activation used inside the head MLPs: "relu" or "silu".
        activation: str = "relu"
        # Width of every head layer. The first layer of each head consumes the
        # image embedding directly, so this presumably has to match the CLIP
        # embedding size -- confirm when changing ``model``.
        hidden_features: int = 512
        heads: List[HeadSpec] = field(default_factory=list)

    cfg: Config

    def configure(self):
        """Load the frozen CLIP model and build one module list per head."""
        self.model, _, self.preprocess = open_clip.create_model_and_transforms(
            self.cfg.model, pretrained=self.cfg.pretrain
        )
        self.model.eval()

        # Do not add the weights in self.model to the optimizer
        for param in self.model.parameters():
            param.requires_grad = False

        assert len(self.cfg.heads) > 0, "At least one head must be configured"

        heads = {}
        for head in self.cfg.heads:
            # Shared trunk: ``n_hidden_layers`` x (Linear -> activation).
            # With zero hidden layers this is an empty (identity) Sequential.
            trunk_layers = []
            for _ in range(head.n_hidden_layers):
                trunk_layers += [
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                ]

            head_layers = [nn.Sequential(*trunk_layers)]
            # Two structurally identical branches, one per distribution
            # parameter, each reducing the features to a single value.
            head_layers += [
                nn.Sequential(
                    nn.Linear(
                        self.cfg.hidden_features,
                        self.cfg.hidden_features,
                    ),
                    self.make_activation(self.cfg.activation),
                    nn.Linear(self.cfg.hidden_features, 1),
                )
                for _ in range(2)
            ]
            # Layout relied upon by ``forward``: [trunk, branch_1, branch_2].
            heads[head.name] = nn.ModuleList(head_layers)
        self.heads = nn.ModuleDict(heads)

    def make_activation(self, activation):
        """Return a fresh in-place activation module for ``activation``.

        Raises:
            NotImplementedError: if ``activation`` is neither "relu" nor
                "silu".
        """
        if activation == "relu":
            return nn.ReLU(inplace=True)
        elif activation == "silu":
            return nn.SiLU(inplace=True)
        else:
            raise NotImplementedError(
                f"Unsupported activation: {activation!r}"
            )

    def _build_distribution(self, d1, d2, output_bias):
        """Turn the two raw branch outputs into a torch distribution."""
        if self.cfg.distribution == "normal":
            mean = d1
            var = d2
            # NOTE(review): d1/d2 were squeezed on their last axis by the
            # caller, so ``mean.shape[-1]`` looks like the batch size here
            # rather than a feature size -- confirm this branch selection is
            # intended.
            if mean.shape[-1] == 1:
                return torch.distributions.Normal(
                    mean + output_bias,
                    torch.nn.functional.softplus(var),
                )
            return torch.distributions.MultivariateNormal(
                mean + output_bias,
                torch.nn.functional.softplus(var).diag_embed(),
            )
        elif self.cfg.distribution == "beta":
            # softplus keeps both concentration parameters positive.
            return torch.distributions.Beta(
                torch.nn.functional.softplus(d1 + output_bias),
                torch.nn.functional.softplus(d2 + output_bias),
            )
        else:
            raise NotImplementedError(
                f"Unsupported distribution: {self.cfg.distribution!r}"
            )

    def _evaluate_distribution(self, dist):
        """Reduce ``dist`` to a tensor according to ``cfg.distribution_eval``."""
        eval_mode = self.cfg.distribution_eval
        if eval_mode == "mean":
            return dist.mean
        if eval_mode == "mode":
            return dist.mode
        if eval_mode == "sample_mean":
            # ``sample([10])`` prepends the sample axis, so the Monte-Carlo
            # average must be taken over dim 0; averaging over dim -1 would
            # collapse the batch/event axis instead of the samples.
            return dist.sample([10]).mean(0)
        # Any other value (documented as "sample") draws a single sample.
        # use rsample if gradient is needed
        return dist.rsample() if self.training else dist.sample()

    def forward(
        self,
        cond_image: Float[Tensor, "B 1 H W 3"],
        sample: bool = True,
    ) -> dict[str, Any]:
        """Encode the conditioning image and evaluate every configured head.

        Args:
            cond_image: channels-last image batch (per its annotation
                ``B 1 H W 3``); the first two axes are merged before encoding.
            sample: when true, each head's distribution is reduced to a
                tensor (see ``cfg.distribution_eval``), passed through the
                head's output activation and stored under ``<name>``, with
                the distribution itself stored under ``<name>_dist``. When
                false, ``<name>`` holds the distribution object.

        Returns:
            Mapping from output names to tensors / distributions. Heads with
            ``add_to_decoder_features`` set are stored under
            ``decoder_<name>`` instead of ``<name>``.

        Raises:
            ValueError: if a head specifies ``shape`` while ``sample`` is
                false.
            NotImplementedError: if ``cfg.distribution`` is not supported.
        """
        # Run the model
        # Resize cond_image to 224
        cond_image = nn.functional.interpolate(
            cond_image.flatten(0, 1).permute(0, 3, 1, 2),
            size=(224, 224),
            mode="bilinear",
            align_corners=False,
        )
        cond_image = Normalize(
            mean=open_clip.constants.OPENAI_DATASET_MEAN,
            std=open_clip.constants.OPENAI_DATASET_STD,
        )(cond_image)
        image_features = self.model.encode_image(cond_image)

        # Run the heads
        outputs = {}
        for head_spec in self.cfg.heads:
            shared_head, d1_h, d2_h = self.heads[head_spec.name]
            shared_features = shared_head(image_features)
            # Each branch ends in Linear(..., 1); drop that trailing axis.
            d1, d2 = [
                branch(shared_features).squeeze(-1) for branch in (d1_h, d2_h)
            ]
            outputs[head_spec.name] = self._build_distribution(
                d1, d2, head_spec.output_bias
            )

        if sample:
            for head_spec in self.cfg.heads:
                head_name = head_spec.name
                dist = outputs[head_name]
                out = self._evaluate_distribution(dist)
                outputs[head_name] = get_activation(
                    head_spec.output_activation
                )(out)
                outputs[f"{head_name}_dist"] = dist

        for head in self.cfg.heads:
            if head.shape:
                if not sample:
                    raise ValueError(
                        "Cannot reshape non-sampled probabilisitic outputs"
                    )
                outputs[head.name] = outputs[head.name].reshape(*head.shape)

            if head.add_to_decoder_features:
                # Re-key the entry; the plain ``<name>`` key is removed.
                outputs[f"decoder_{head.name}"] = outputs[head.name]
                del outputs[head.name]

        return outputs