import torch
import torch.nn.functional as F
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class LstmSeq2SeqEncoder(nn.Module):
    """Encode a padded batch of sequences with an (optionally bidirectional)
    LSTM, using the padding mask to skip pad positions via packed sequences."""

    def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
        super().__init__()
        self.lstm = nn.LSTM(input_size=input_size,
                            hidden_size=hidden_size,
                            num_layers=num_layers,
                            dropout=dropout,
                            bidirectional=bidirectional,
                            batch_first=True)  # tensors are (batch, seq_len, features)

    def forward(self, x, mask, hidden=None):
        # Per-sequence lengths from the padding mask; pack_padded_sequence
        # requires the lengths tensor to live on the CPU.
        lengths = mask.sum(dim=1).cpu()

        # Pack so the LSTM skips padded positions; enforce_sorted=False
        # accepts batches in any length order.
        packed_x = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)

        packed_output, hidden = self.lstm(packed_x, hidden)

        # Unpack to a padded (batch, max_len, num_directions * hidden_size) tensor;
        # the final hidden state is discarded since only per-step outputs are needed.
        output, _ = pad_packed_sequence(packed_output, batch_first=True)
        return output
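

# Minimal usage sketch (not part of the original module; shapes and values
# below are illustrative assumptions): encode a padded batch with a boolean
# mask and check that a bidirectional encoder doubles the feature dimension.
if __name__ == "__main__":
    encoder = LstmSeq2SeqEncoder(input_size=128, hidden_size=256, bidirectional=True)
    x = torch.randn(4, 10, 128)                 # (batch, seq_len, input_size)
    mask = torch.ones(4, 10, dtype=torch.bool)  # True marks real (non-pad) tokens
    mask[2, 6:] = False                         # simulate padding on one sequence
    out = encoder(x, mask)
    print(out.shape)                            # torch.Size([4, 10, 512])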