|
import torch |
|
import torch.nn as nn |
|
|
|
from functools import reduce |
|
from torch.autograd import Variable |
|
|
|
class LambdaBase(nn.Sequential):
    """Container that remembers a plain callable alongside its child modules.

    Subclasses combine the stored callable (``lambda_func``) with the result
    of :meth:`forward_prepare`.

    Args:
        fn: callable stored as ``self.lambda_func``.
        *args: child modules, registered exactly as ``nn.Sequential`` would.
    """

    def __init__(self, fn, *args):
        super(LambdaBase, self).__init__(*args)
        self.lambda_func = fn

    def forward_prepare(self, input):
        """Feed ``input`` independently to every registered child.

        Returns:
            A list with one output per child, in registration order. When no
            children are registered, ``input`` itself is returned unchanged.
        """
        child_outputs = [child(input) for child in self._modules.values()]
        if not child_outputs:
            return input
        return child_outputs
|
|
|
class Lambda(LambdaBase):
    """Apply ``lambda_func`` once to the whole prepared input.

    The prepared input is the list of child outputs, or the raw input when
    the module has no children.
    """

    def forward(self, input):
        prepared = self.forward_prepare(input)
        result = self.lambda_func(prepared)
        return result
|
|
|
class LambdaMap(LambdaBase):
    """Apply ``lambda_func`` to each element of the prepared input.

    Returns a list with one mapped value per element, in iteration order.
    """

    def forward(self, input):
        fn = self.lambda_func
        prepared = self.forward_prepare(input)
        return [fn(element) for element in prepared]
|
|
|
class LambdaReduce(LambdaBase):
    """Fold the prepared input left-to-right with ``lambda_func``.

    Uses ``functools.reduce`` without an initial value, so the prepared
    input must be non-empty.
    """

    def forward(self, input):
        combine = self.lambda_func
        prepared = self.forward_prepare(input)
        folded = reduce(combine, prepared)
        return folded
|
|
|
def get_model(load_weights=True,
              weights_path='model_files/pretrained_model_reloaded_th.pth'):
    """Build the Basset convolutional network converted from Torch7.

    Sequences are one-hot encoded with the channel order used by Basset's
    ``dna_io`` (A=0, C=1, G=2, T=3), see
    https://github.com/davek44/Basset/tree/master/src/dna_io.py#L145-L148

    The convolutions use ``(k, 1)`` kernels on 4 input channels. Input is
    presumably shaped ``(batch, 4, length, 1)`` (TODO confirm with callers).
    The first fully connected layer expects 2000 flattened features
    (200 channels x 10 positions). With width 1, the conv/pool stack
    produces exactly that for input lengths 600-647, e.g. 600.

    Args:
        load_weights: if True, load a saved state dict into the network.
        weights_path: location of the state dict. The default is relative to
            the current working directory, as in the original code.

    Returns:
        An ``nn.Sequential`` whose final layer is a sigmoid over 164 outputs.
    """

    def _linear(in_features, out_features):
        # Keep the (Lambda, Linear) pair wrapped in its own Sequential so the
        # parameter names match the original nesting (e.g. "13.1.weight").
        return nn.Sequential(
            # Promote a 1-D tensor to a single-row batch; leave others as-is.
            Lambda(lambda x: x.view(1, -1) if 1 == len(x.size()) else x),
            nn.Linear(in_features, out_features),
        )

    model = nn.Sequential(
        nn.Conv2d(4, 300, (19, 1)),
        nn.BatchNorm2d(300),
        nn.ReLU(),
        nn.MaxPool2d((3, 1), (3, 1)),
        nn.Conv2d(300, 200, (11, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1), (4, 1)),
        nn.Conv2d(200, 200, (7, 1)),
        nn.BatchNorm2d(200),
        nn.ReLU(),
        nn.MaxPool2d((4, 1), (4, 1)),
        # Flatten every dimension except the first (batch) one.
        Lambda(lambda x: x.view(x.size(0), -1)),
        _linear(2000, 1000),
        nn.BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True),
        nn.ReLU(),
        nn.Dropout(0.3),
        _linear(1000, 1000),
        nn.BatchNorm1d(1000, eps=1e-05, momentum=0.1, affine=True),
        nn.ReLU(),
        nn.Dropout(0.3),
        _linear(1000, 164),
        nn.Sigmoid(),
    )

    if load_weights:
        # map_location='cpu' lets tensors saved from a GPU load on CPU-only
        # hosts; load_state_dict then copies them into the model's parameters.
        # torch.load unpickles the file, so only point it at trusted files.
        state_dict = torch.load(weights_path, map_location='cpu')
        model.load_state_dict(state_dict)

    return model
|
|
|
# Module-level network instance, built without loading the saved weights.
model = get_model(load_weights=False)
|
|
|
|