Upload camera_proj.py
Browse files- camera_proj/camera_proj.py +54 -0
camera_proj/camera_proj.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch.nn as nn
|
4 |
+
|
5 |
+
from diffusers.models.activations import get_activation
|
6 |
+
from diffusers.models.modeling_utils import ModelMixin
|
7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
8 |
+
|
9 |
+
class CameraMatrixEmbedding(ModelMixin, ConfigMixin):
    """Two-layer MLP embedding: ``linear_1 -> act -> linear_2 -> [post_act]``.

    The class name suggests the input is a flattened camera matrix; the code
    itself only requires that the last dimension of the input equals
    ``in_channels``. An optional conditioning tensor can be projected (without
    bias) to ``in_channels`` and added to the input before the first layer.

    Args:
        in_channels: Size of the last dimension of the input to ``forward``.
        camera_embed_dim: Hidden width (output size of ``linear_1``).
        act_fn: Name passed to ``get_activation`` for the activation applied
            between the two linear layers.
        out_dim: Output size of ``linear_2``. Falls back to
            ``camera_embed_dim`` when ``None``.
        post_act_fn: Name passed to ``get_activation`` for an activation
            applied after ``linear_2``. No activation is applied when ``None``.
        cond_proj_dim: Size of the last dimension of the optional ``condition``
            input. When ``None`` no conditioning projection is created and
            ``forward`` must be called without ``condition``.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int,
        camera_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim: Optional[int] = None,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, camera_embed_dim)

        # The conditioning projection maps into the *input* space so that it
        # can be summed with the sample before linear_1.
        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        camera_embed_dim_out = out_dim if out_dim is not None else camera_embed_dim
        self.linear_2 = nn.Linear(camera_embed_dim, camera_embed_dim_out)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        """Embed ``sample``, optionally shifted by a projected ``condition``.

        Args:
            sample: Input whose last dimension is ``in_channels``.
            condition: Optional input whose last dimension is
                ``cond_proj_dim``; it is projected and added to ``sample``.

        Returns:
            The embedded input; its last dimension is ``out_dim`` if that was
            given at construction, otherwise ``camera_embed_dim``.

        Raises:
            ValueError: If ``condition`` is given but the module was built
                with ``cond_proj_dim=None``.
        """
        if condition is not None:
            if self.cond_proj is None:
                # Fail with an actionable message instead of the opaque
                # "'NoneType' object is not callable".
                raise ValueError(
                    "`condition` was provided but this module was created with "
                    "`cond_proj_dim=None`, so there is no projection layer for it."
                )
            sample = sample + self.cond_proj(condition)

        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample
|