dennistrujillo commited on
Commit
beef824
1 Parent(s): 769a90a

added model.py

Browse files
Files changed (1) hide show
  1. model.py +81 -0
model.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ def model_init(m):
4
+ if isinstance(m, torch.nn.Conv2d) or isinstance(m, torch.nn.Linear):
5
+ torch.nn.init.xavier_uniform_(m.weight)
6
+ torch.nn.init.zeros_(m.bias)
7
+
8
+ class NLB(torch.nn.Module):
9
+ def __init__(self, in_ch, relu_a=0.01):
10
+ self.inter_ch = torch.div(in_ch, 2, rounding_mode='floor').item()
11
+ super().__init__()
12
+ self.theta_layer = torch.nn.Conv2d(in_channels=in_ch, out_channels=self.inter_ch, \
13
+ kernel_size=1, padding=0)
14
+ self.phi_layer = torch.nn.Conv2d(in_channels=in_ch, out_channels=self.inter_ch, \
15
+ kernel_size=1, padding=0)
16
+ self.g_layer = torch.nn.Conv2d(in_channels=in_ch, out_channels=self.inter_ch, \
17
+ kernel_size=1, padding=0)
18
+ self.atten_act = torch.nn.Softmax(dim=-1)
19
+ self.out_cnn = torch.nn.Conv2d(in_channels=self.inter_ch, out_channels=in_ch, \
20
+ kernel_size=1, padding=0)
21
+
22
+ def forward(self, x):
23
+ mbsz, _, h, w = x.size()
24
+
25
+ theta = self.theta_layer(x).view(mbsz, self.inter_ch, -1).permute(0, 2, 1)
26
+ phi = self.phi_layer(x).view(mbsz, self.inter_ch, -1)
27
+ g = self.g_layer(x).view(mbsz, self.inter_ch, -1).permute(0, 2, 1)
28
+
29
+ theta_phi = self.atten_act(torch.matmul(theta, phi))
30
+
31
+ theta_phi_g = torch.matmul(theta_phi, g).permute(0, 2, 1).view(mbsz, self.inter_ch, h, w)
32
+
33
+ _out_tmp = self.out_cnn(theta_phi_g)
34
+ _out_tmp = torch.add(_out_tmp, x)
35
+
36
+ return _out_tmp
37
+
38
+
39
+ class BraggNN(torch.nn.Module):
40
+ def __init__(self, imgsz, fcsz=(64, 32, 16, 8)):
41
+ super().__init__()
42
+ self.cnn_ops = []
43
+ cnn_out_chs = (64, 32, 8)
44
+ cnn_in_chs = (1, ) + cnn_out_chs[:-1]
45
+ fsz = imgsz
46
+ for ic, oc, in zip(cnn_in_chs, cnn_out_chs):
47
+ self.cnn_ops += [
48
+ torch.nn.Conv2d(in_channels=ic, out_channels=oc, kernel_size=3, \
49
+ stride=1, padding=0),
50
+ torch.nn.LeakyReLU(negative_slope=0.01),
51
+ ]
52
+ fsz -= 2
53
+ self.nlb = NLB(in_ch=cnn_out_chs[0])
54
+ self.dense_ops = []
55
+ dense_in_chs = (fsz * fsz * cnn_out_chs[-1], ) + fcsz[:-1]
56
+ for ic, oc in zip(dense_in_chs, fcsz):
57
+ self.dense_ops += [
58
+ torch.nn.Linear(ic, oc),
59
+ torch.nn.LeakyReLU(negative_slope=0.01),
60
+ ]
61
+ # output layer
62
+ self.dense_ops += [torch.nn.Linear(fcsz[-1], 2), ]
63
+
64
+ self.cnn_layers = torch.nn.Sequential(*self.cnn_ops)
65
+ self.dense_layers = torch.nn.Sequential(*self.dense_ops)
66
+
67
+ def forward(self, x):
68
+ _out = x
69
+ for layer in self.cnn_layers[:1]:
70
+ _out = layer(_out)
71
+
72
+ _out = self.nlb(_out)
73
+
74
+ for layer in self.cnn_layers[1:]:
75
+ _out = layer(_out)
76
+
77
+ _out = _out.flatten(start_dim=1)
78
+ for layer in self.dense_layers:
79
+ _out = layer(_out)
80
+
81
+ return _out