|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""LSTM layers module.""" |
|
|
|
from torch import nn |
|
|
|
|
|
class SLSTM(nn.Module):
    """LSTM without worrying about the hidden state, nor the layout of the data.

    Expects input in the convolutional layout ``(batch, channels, time)`` and
    returns output in the same layout. Internally the tensor is permuted to
    ``(time, batch, channels)``, which is what ``nn.LSTM`` expects with its
    default ``batch_first=False``.

    Args:
        dimension: Number of input channels; also used as the LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        skip: If True, add the (layout-permuted) input to the LSTM output as a
            residual connection.
        bidirectional: If True, use a bidirectional LSTM. The output then has
            ``2 * dimension`` channels, and for the residual connection the
            input is tiled twice along the channel axis to match.
    """

    def __init__(
        self,
        dimension: int,
        num_layers: int = 2,
        skip: bool = True,
        bidirectional: bool = False,
    ):
        super().__init__()
        self.bidirectional = bidirectional
        self.skip = skip
        self.lstm = nn.LSTM(
            dimension, dimension, num_layers, bidirectional=bidirectional
        )

    def forward(self, x):
        """Run the LSTM over ``x`` laid out as ``(batch, channels, time)``.

        Returns a tensor laid out as ``(batch, out_channels, time)``, where
        ``out_channels`` is ``dimension`` (or ``2 * dimension`` when
        bidirectional). The final hidden and cell states are discarded.
        """
        x = x.permute(2, 0, 1)  # (B, C, T) -> (T, B, C) for nn.LSTM
        y, _ = self.lstm(x)  # (T, B, C * num_directions)
        if self.skip:
            # Tile the input only when the residual is actually used: the
            # repeat allocates a full copy, which is wasted when skip is off.
            residual = x.repeat(1, 1, 2) if self.bidirectional else x
            y = y + residual
        return y.permute(1, 2, 0)  # (T, B, C') -> (B, C', T)
|
|