haimasree commited on
Commit
e2c9dca
1 Parent(s): c6a8919

Create pretrained_model_reloaded_th.py

Browse files
Files changed (1) hide show
  1. pretrained_model_reloaded_th.py +70 -0
pretrained_model_reloaded_th.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from functools import reduce
5
+ from torch.autograd import Variable
6
+
7
+ class LambdaBase(nn.Sequential):
8
+ def __init__(self, fn, *args):
9
+ super(LambdaBase, self).__init__(*args)
10
+ self.lambda_func = fn
11
+
12
+ def forward_prepare(self, input):
13
+ output = []
14
+ for module in self._modules.values():
15
+ output.append(module(input))
16
+ return output if output else input
17
+
18
+ class Lambda(LambdaBase):
19
+ def forward(self, input):
20
+ return self.lambda_func(self.forward_prepare(input))
21
+
22
+ class LambdaMap(LambdaBase):
23
+ def forward(self, input):
24
+ return list(map(self.lambda_func,self.forward_prepare(input)))
25
+
26
+ class LambdaReduce(LambdaBase):
27
+ def forward(self, input):
28
+ return reduce(self.lambda_func,self.forward_prepare(input))
29
+
30
+ def get_model(load_weights = True):
31
+ # alphabet seems to be fine:
32
+ """
33
+ https://github.com/davek44/Basset/tree/master/src/dna_io.py#L145-L148
34
+ seq = seq.replace('A','0')
35
+ seq = seq.replace('C','1')
36
+ seq = seq.replace('G','2')
37
+ seq = seq.replace('T','3')
38
+ """
39
+ pretrained_model_reloaded_th = nn.Sequential( # Sequential,
40
+ nn.Conv2d(4,300,(19, 1)),
41
+ nn.BatchNorm2d(300),
42
+ nn.ReLU(),
43
+ nn.MaxPool2d((3, 1),(3, 1)),
44
+ nn.Conv2d(300,200,(11, 1)),
45
+ nn.BatchNorm2d(200),
46
+ nn.ReLU(),
47
+ nn.MaxPool2d((4, 1),(4, 1)),
48
+ nn.Conv2d(200,200,(7, 1)),
49
+ nn.BatchNorm2d(200),
50
+ nn.ReLU(),
51
+ nn.MaxPool2d((4, 1),(4, 1)),
52
+ Lambda(lambda x: x.view(x.size(0),-1)), # Reshape,
53
+ nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(2000,1000)), # Linear,
54
+ nn.BatchNorm1d(1000,1e-05,0.1,True),#BatchNorm1d,
55
+ nn.ReLU(),
56
+ nn.Dropout(0.3),
57
+ nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(1000,1000)), # Linear,
58
+ nn.BatchNorm1d(1000,1e-05,0.1,True),#BatchNorm1d,
59
+ nn.ReLU(),
60
+ nn.Dropout(0.3),
61
+ nn.Sequential(Lambda(lambda x: x.view(1,-1) if 1==len(x.size()) else x ),nn.Linear(1000,164)), # Linear,
62
+ nn.Sigmoid(),
63
+ )
64
+ if load_weights:
65
+ sd = torch.load('model_files/pretrained_model_reloaded_th.pth')
66
+ pretrained_model_reloaded_th.load_state_dict(sd)
67
+ return pretrained_model_reloaded_th
68
+
69
+ model = get_model(load_weights = False)
70
+