szukevin's picture
upload
7900c16
raw
history blame
1.61 kB
import torch.nn as nn
class Model(nn.Module):
"""
Pretraining models consist of three (five) parts:
- embedding
- encoder
- tgt_embedding (optional)
- decoder (optional)
- target
"""
def __init__(self, args, embedding, encoder, tgt_embedding, decoder, target):
super(Model, self).__init__()
self.embedding = embedding
self.encoder = encoder
self.tgt_embedding = tgt_embedding
self.decoder = decoder
self.target = target
if "mlm" in args.target and args.tie_weights:
self.target.mlm.linear_2.weight = self.embedding.word.embedding.weight
elif "lm" in args.target and args.tie_weights and "word" in self.embedding.embedding_name_list:
self.target.lm.output_layer.weight = self.embedding.word.embedding.weight
elif "lm" in args.target and args.tie_weights and "word" in self.tgt_embedding.embedding_name_list:
self.target.lm.output_layer.weight = self.tgt_embedding.word.embedding.weight
if self.decoder is not None and args.share_embedding:
self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight
def forward(self, src, tgt, seg, tgt_in=None, tgt_seg=None):
emb = self.embedding(src, seg)
memory_bank = self.encoder(emb, seg)
if self.decoder:
tgt_emb = self.tgt_embedding(tgt_in, tgt_seg)
memory_bank = self.decoder(memory_bank, tgt_emb, (seg, tgt_seg))
loss_info = self.target(memory_bank, tgt, seg)
return loss_info