lzq49 commited on
Commit
7b693b2
1 Parent(s): 75dec80

Upload camera_proj.py

Browse files
Files changed (1) hide show
  1. camera_proj/camera_proj.py +54 -0
camera_proj/camera_proj.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Optional
2
+
3
+ import torch.nn as nn
4
+
5
+ from diffusers.models.activations import get_activation
6
+ from diffusers.models.modeling_utils import ModelMixin
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+
9
+ class CameraMatrixEmbedding(ModelMixin, ConfigMixin):
10
+ @register_to_config
11
+ def __init__(
12
+ self,
13
+ in_channels: int,
14
+ camera_embed_dim: int,
15
+ act_fn: str = "silu",
16
+ out_dim: int = None,
17
+ post_act_fn: Optional[str] = None,
18
+ cond_proj_dim=None,
19
+ ):
20
+ super().__init__()
21
+
22
+ self.linear_1 = nn.Linear(in_channels, camera_embed_dim)
23
+
24
+ if cond_proj_dim is not None:
25
+ self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
26
+ else:
27
+ self.cond_proj = None
28
+
29
+ self.act = get_activation(act_fn)
30
+
31
+ if out_dim is not None:
32
+ camera_embed_dim_out = out_dim
33
+ else:
34
+ camera_embed_dim_out = camera_embed_dim
35
+ self.linear_2 = nn.Linear(camera_embed_dim, camera_embed_dim_out)
36
+
37
+ if post_act_fn is None:
38
+ self.post_act = None
39
+ else:
40
+ self.post_act = get_activation(post_act_fn)
41
+
42
+ def forward(self, sample, condition=None):
43
+ if condition is not None:
44
+ sample = sample + self.cond_proj(condition)
45
+ sample = self.linear_1(sample)
46
+
47
+ if self.act is not None:
48
+ sample = self.act(sample)
49
+
50
+ sample = self.linear_2(sample)
51
+
52
+ if self.post_act is not None:
53
+ sample = self.post_act(sample)
54
+ return sample