import torch
import torch.nn as nn

# The wildcard import supplies flip(), which BirnnEncoder below uses to
# reverse sequences along the time dimension.
from tencentpretrain.utils.misc import *


class RnnEncoder(nn.Module):
    """
    Vanilla RNN encoder.

    Expects an args object carrying emb_size, hidden_size, layers_num,
    dropout, and bidirectional. In the bidirectional case, hidden_size
    must be even: each direction gets half of it, so the concatenated
    output still has hidden_size features.
    """

    def __init__(self, args):
        super(RnnEncoder, self).__init__()
        self.bidirectional = args.bidirectional
        if self.bidirectional:
            assert args.hidden_size % 2 == 0
            self.hidden_size = args.hidden_size // 2
        else:
            self.hidden_size = args.hidden_size
        self.layers_num = args.layers_num

        self.rnn = nn.RNN(input_size=args.emb_size,
                          hidden_size=self.hidden_size,
                          num_layers=args.layers_num,
                          dropout=args.dropout,
                          batch_first=True,
                          bidirectional=self.bidirectional)

        self.drop = nn.Dropout(args.dropout)

    def forward(self, emb, _):
        # The second argument (the segment tensor) is unused by RNN encoders.
        self.rnn.flatten_parameters()
        hidden = self.init_hidden(emb.size(0), emb.device)
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        return output

    def init_hidden(self, batch_size, device):
        if self.bidirectional:
            return torch.zeros(self.layers_num * 2, batch_size, self.hidden_size, device=device)
        else:
            return torch.zeros(self.layers_num, batch_size, self.hidden_size, device=device)
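

# A minimal sketch of the hyperparameter namespace these encoders expect.
# The field names mirror the attributes read in __init__ above; the values
# and the SimpleNamespace container are illustrative assumptions only.
def _demo_args():
    from types import SimpleNamespace
    return SimpleNamespace(emb_size=32, hidden_size=64, layers_num=2,
                           dropout=0.1, bidirectional=True)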


class LstmEncoder(RnnEncoder):
    """
    LSTM encoder.
    """

    def __init__(self, args):
        super(LstmEncoder, self).__init__(args)
        self.rnn = nn.LSTM(input_size=args.emb_size,
                           hidden_size=self.hidden_size,
                           num_layers=args.layers_num,
                           dropout=args.dropout,
                           batch_first=True,
                           bidirectional=self.bidirectional)

    def init_hidden(self, batch_size, device):
        # The LSTM state is a (hidden, cell) pair rather than a single tensor.
        if self.bidirectional:
            return (torch.zeros(self.layers_num * 2, batch_size, self.hidden_size, device=device),
                    torch.zeros(self.layers_num * 2, batch_size, self.hidden_size, device=device))
        else:
            return (torch.zeros(self.layers_num, batch_size, self.hidden_size, device=device),
                    torch.zeros(self.layers_num, batch_size, self.hidden_size, device=device))


class GruEncoder(RnnEncoder):
    """
    GRU encoder.
    """

    def __init__(self, args):
        super(GruEncoder, self).__init__(args)
        self.rnn = nn.GRU(input_size=args.emb_size,
                          hidden_size=self.hidden_size,
                          num_layers=args.layers_num,
                          dropout=args.dropout,
                          batch_first=True,
                          bidirectional=self.bidirectional)


class BirnnEncoder(nn.Module):
    """
    Bi-directional RNN encoder built from two independent unidirectional
    stacks: the backward stack reads the flipped sequence, and its output
    is flipped back before the two directions are concatenated.
    """

    def __init__(self, args):
        super(BirnnEncoder, self).__init__()
        assert args.hidden_size % 2 == 0
        self.hidden_size = args.hidden_size // 2
        self.layers_num = args.layers_num

        self.rnn_forward = nn.RNN(input_size=args.emb_size,
                                  hidden_size=self.hidden_size,
                                  num_layers=args.layers_num,
                                  dropout=args.dropout,
                                  batch_first=True)

        self.rnn_backward = nn.RNN(input_size=args.emb_size,
                                   hidden_size=self.hidden_size,
                                   num_layers=args.layers_num,
                                   dropout=args.dropout,
                                   batch_first=True)

        self.drop = nn.Dropout(args.dropout)

    def forward(self, emb, _):
        # Forward direction: encode the sequence as-is.
        self.rnn_forward.flatten_parameters()
        emb_forward = emb
        hidden_forward = self.init_hidden(emb_forward.size(0), emb_forward.device)
        output_forward, hidden_forward = self.rnn_forward(emb_forward, hidden_forward)
        output_forward = self.drop(output_forward)

        # Backward direction: flip along the sequence dimension, encode,
        # then flip the output back so positions align with the forward pass.
        self.rnn_backward.flatten_parameters()
        emb_backward = flip(emb, 1)
        hidden_backward = self.init_hidden(emb_backward.size(0), emb_backward.device)
        output_backward, hidden_backward = self.rnn_backward(emb_backward, hidden_backward)
        output_backward = self.drop(output_backward)
        output_backward = flip(output_backward, 1)

        return torch.cat([output_forward, output_backward], 2)

    def init_hidden(self, batch_size, device):
        return torch.zeros(self.layers_num, batch_size, self.hidden_size, device=device)


class BilstmEncoder(BirnnEncoder):
    """
    Bi-directional LSTM encoder.
    """

    def __init__(self, args):
        super(BilstmEncoder, self).__init__(args)
        self.rnn_forward = nn.LSTM(input_size=args.emb_size,
                                   hidden_size=self.hidden_size,
                                   num_layers=args.layers_num,
                                   dropout=args.dropout,
                                   batch_first=True)

        self.rnn_backward = nn.LSTM(input_size=args.emb_size,
                                    hidden_size=self.hidden_size,
                                    num_layers=args.layers_num,
                                    dropout=args.dropout,
                                    batch_first=True)

    def init_hidden(self, batch_size, device):
        return (torch.zeros(self.layers_num, batch_size, self.hidden_size, device=device),
                torch.zeros(self.layers_num, batch_size, self.hidden_size, device=device))


class BigruEncoder(BirnnEncoder):
    """
    Bi-directional GRU encoder.
    """

    def __init__(self, args):
        super(BigruEncoder, self).__init__(args)
        self.rnn_forward = nn.GRU(input_size=args.emb_size,
                                  hidden_size=self.hidden_size,
                                  num_layers=args.layers_num,
                                  dropout=args.dropout,
                                  batch_first=True)

        self.rnn_backward = nn.GRU(input_size=args.emb_size,
                                   hidden_size=self.hidden_size,
                                   num_layers=args.layers_num,
                                   dropout=args.dropout,
                                   batch_first=True)
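

# A minimal smoke-test sketch: it runs each encoder on random embeddings
# using the illustrative _demo_args() namespace defined above and checks
# that every variant maps a (batch, seq_len, emb_size) input to a
# (batch, seq_len, hidden_size) output. Running it assumes tencentpretrain
# is importable, since the Bi* encoders rely on flip() from utils.misc.
if __name__ == "__main__":
    args = _demo_args()
    emb = torch.randn(4, 10, args.emb_size)
    for encoder_class in (RnnEncoder, LstmEncoder, GruEncoder,
                          BirnnEncoder, BilstmEncoder, BigruEncoder):
        encoder = encoder_class(args)
        output = encoder(emb, None)
        assert output.shape == (4, 10, args.hidden_size), encoder_class.__name__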