|
import logging

import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, PreTrainedModel

from config import SRLModelConfig
|
|
|
|
|
class SRLModel(PreTrainedModel):
    """Pretrained transformer encoder with a per-token linear tagging head.

    The encoder and tokenizer are loaded from ``config.bert_model_name``.
    Each token's final hidden state goes through dropout and a linear layer
    that produces ``config.num_labels`` scores per token.
    """

    config_class = SRLModelConfig

    def __init__(self, config):
        """Build the encoder, tokenizer, dropout and tag projection layer.

        Args:
            config: An ``SRLModelConfig``. This constructor reads
                ``num_labels``, ``bert_model_name``, ``embedding_dropout``,
                ``id2label`` and ``label2id`` from it.
        """
        super().__init__(config)

        # Report the key hyper-parameters through logging rather than
        # printing to stdout every time a model is constructed.
        logging.getLogger(__name__).debug(
            "SRLModel config: num_labels=%s, bert_model_name=%s, "
            "embedding_dropout=%s",
            config.num_labels,
            config.bert_model_name,
            config.embedding_dropout,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(config.bert_model_name)
        self.transformer = AutoModel.from_pretrained(
            config.bert_model_name,
            num_labels=config.num_labels,
            output_hidden_states=True,
        )
        # Keep the encoder's own config in sync with the label vocabulary.
        self.transformer.config.id2label = config.id2label
        self.transformer.config.label2id = config.label2id

        if "xlm" in config.bert_model_name or "roberta" in config.bert_model_name:
            # Give these encoders a two-row token-type embedding table so
            # that ``token_type_ids`` with the values 0 and 1 can be looked up.
            # NOTE(review): presumably needed because these checkpoints ship
            # with a single token type -- confirm against the checkpoints used.
            self.transformer.config.type_vocab_size = 2
            self.transformer.embeddings.token_type_embeddings = nn.Embedding(
                2, self.transformer.config.hidden_size
            )
            # The replacement table is freshly initialised (not pretrained).
            self.transformer.embeddings.token_type_embeddings.weight.data.normal_(
                mean=0.0, std=self.transformer.config.initializer_range
            )

        # Maps each token's hidden state to one score per label.
        self.tag_projection_layer = nn.Linear(
            self.transformer.config.hidden_size, config.num_labels
        )
        self.embedding_dropout = nn.Dropout(p=config.embedding_dropout)
        self.num_labels = config.num_labels

    def forward(self, input_ids, attention_mask, token_type_ids, labels=None):
        """Run the encoder and tagging head, optionally computing the loss.

        Args:
            input_ids: Token ids, forwarded to the encoder.
            attention_mask: Attention mask, forwarded to the encoder.
            token_type_ids: Token type ids, forwarded to the encoder.
            labels: Optional target label ids, one per token position.
                Positions equal to ``-100`` are ignored by the loss.

        Returns:
            A dict with:
                ``"logits"``: per-token label scores of shape
                    ``(batch, seq_len, num_labels)``.
                ``"class_probabilities"``: softmax of ``logits`` over the
                    label dimension, same shape.
                ``"attention_mask"`` and ``"input_ids"``: the inputs, passed
                    through unchanged.
                ``"loss"``: cross-entropy over all non-ignored positions;
                    present only when ``labels`` is given.
        """
        encoder_outputs = self.transformer(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )

        # Only the final layer's hidden states feed the tagging head.
        token_states = self.embedding_dropout(encoder_outputs.last_hidden_state)
        logits = self.tag_projection_layer(token_states)

        # Softmax over the last dimension directly; flattening to 2-D and
        # reshaping back yields the same values.
        class_probabilities = F.softmax(logits, dim=-1)

        output_dict = {
            "logits": logits,
            "class_probabilities": class_probabilities,
            "attention_mask": attention_mask,
            "input_ids": input_ids,
        }

        if labels is not None:
            # ``reshape`` (unlike ``view``) also accepts non-contiguous
            # tensors, e.g. labels produced by slicing or transposing.
            output_dict["loss"] = F.cross_entropy(
                logits.reshape(-1, self.num_labels),
                labels.reshape(-1),
                ignore_index=-100,
            )

        return output_dict
|
|