Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
class Model(nn.Module):
    """
    Pretraining models consist of three (five) parts:
        - embedding
        - encoder
        - tgt_embedding (optional)
        - decoder (optional)
        - target

    The optional parts are passed as ``None`` when unused. The constructor
    can also tie the target's output projection to a word-embedding matrix
    and share the source word embedding with the target-side embedding,
    depending on flags read from ``args``.
    """

    def __init__(self, args, embedding, encoder, tgt_embedding, decoder, target):
        """Store the sub-modules and apply the requested weight sharing.

        Args:
            args: Configuration object. This constructor reads
                ``args.target``, ``args.tie_weights`` and
                ``args.share_embedding``.
            embedding: Source-side embedding module.
            encoder: Encoder module.
            tgt_embedding: Target-side embedding module, or ``None``.
            decoder: Decoder module, or ``None``.
            target: Module that turns the final hidden states into loss
                information.
        """
        super().__init__()
        self.embedding = embedding
        self.encoder = encoder
        self.tgt_embedding = tgt_embedding
        self.decoder = decoder
        self.target = target

        # Tie the output projection of the target to a word-embedding matrix.
        # The branches are mutually exclusive and checked in priority order:
        # "mlm" first, then "lm" with the source embedding, then "lm" with the
        # target-side embedding.
        if "mlm" in args.target and args.tie_weights:
            self.target.mlm.linear_2.weight = self.embedding.word.embedding.weight
        elif "lm" in args.target and args.tie_weights and "word" in self.embedding.embedding_name_list:
            self.target.lm.output_layer.weight = self.embedding.word.embedding.weight
        elif (
            "lm" in args.target
            and args.tie_weights
            # tgt_embedding is optional; without this guard an encoder-only
            # model whose source embedding has no "word" part would raise
            # AttributeError on None instead of simply skipping the tying.
            and self.tgt_embedding is not None
            and "word" in self.tgt_embedding.embedding_name_list
        ):
            self.target.lm.output_layer.weight = self.tgt_embedding.word.embedding.weight

        # Let the target-side embedding reuse the source word-embedding matrix.
        if self.decoder is not None and args.share_embedding:
            self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight

    def forward(self, src, tgt, seg, tgt_in=None, tgt_seg=None):
        """Run embedding -> encoder -> (decoder) -> target.

        Args:
            src: Source input, passed to ``embedding`` together with ``seg``.
            tgt: Target labels, passed through to ``target``.
            seg: Source segment information; passed to ``embedding``,
                ``encoder``, ``decoder`` (as part of a tuple) and ``target``.
            tgt_in: Decoder input, only used when a decoder is present.
            tgt_seg: Decoder-side segment information, only used when a
                decoder is present.

        Returns:
            Whatever ``target`` returns for the final hidden states.
        """
        emb = self.embedding(src, seg)
        memory_bank = self.encoder(emb, seg)

        # Compare against None (as __init__ does) rather than relying on
        # truthiness: a module that defines __len__ is falsy when empty and
        # would otherwise be skipped silently.
        if self.decoder is not None:
            tgt_emb = self.tgt_embedding(tgt_in, tgt_seg)
            memory_bank = self.decoder(memory_bank, tgt_emb, (seg, tgt_seg))

        loss_info = self.target(memory_bank, tgt, seg)
        return loss_info