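"""Joint intent classification and slot filling on top of an XLM-RoBERTa encoder.

A single XLM-RoBERTa encoder feeds two heads: a sentence-level intent classifier
and a token-level slot classifier; slot tags can optionally be decoded with a CRF.
"""
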
import torch
import torch.nn as nn
from torchcrf import CRF
from transformers.models.roberta.modeling_roberta import RobertaPreTrainedModel
from transformers.models.xlm_roberta.modeling_xlm_roberta import XLMRobertaModel
from .module import IntentClassifier, SlotClassifier


class JointXLMR(RobertaPreTrainedModel):
    def __init__(self, config, args, intent_label_lst, slot_label_lst):
        super(JointXLMR, self).__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        self.num_slot_labels = len(slot_label_lst)
        self.roberta = XLMRobertaModel(config)  # XLM-RoBERTa encoder (weights are loaded via from_pretrained on the joint model)
        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
        self.slot_classifier = SlotClassifier(
            config.hidden_size,
            self.num_intent_labels,
            self.num_slot_labels,
            self.args.use_intent_context_concat,
            self.args.use_intent_context_attention,
            self.args.max_seq_len,
            self.args.attention_embedding_size,
            args.dropout_rate,
        )
        if args.use_crf:
            self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
        outputs = self.roberta(
            input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids
        )  # sequence_output, pooled_output, (hidden_states), (attentions)
        sequence_output = outputs[0]
        pooled_output = outputs[1]  # pooled <s> token (RoBERTa's equivalent of [CLS])

        intent_logits = self.intent_classifier(pooled_output)

        if not self.args.use_attention_mask:
            tmp_attention_mask = None
        else:
            tmp_attention_mask = attention_mask

        if self.args.embedding_type == "hard":
            # Convert the intent distribution into a one-hot vector before feeding
            # it to the slot classifier as intent context.
            hard_intent_logits = torch.zeros(intent_logits.shape, device=intent_logits.device)
            for i, sample in enumerate(intent_logits):
                max_idx = torch.argmax(sample)
                hard_intent_logits[i][max_idx] = 1
            slot_logits = self.slot_classifier(sequence_output, hard_intent_logits, tmp_attention_mask)
        else:
            slot_logits = self.slot_classifier(sequence_output, intent_logits, tmp_attention_mask)

        total_loss = 0
        # 1. Intent Softmax
        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
            else:
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(
                    intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1)
                )
            total_loss += self.args.intent_loss_coef * intent_loss

        # 2. Slot Softmax
        if slot_labels_ids is not None:
            if self.args.use_crf:
                slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction="mean")
                slot_loss = -1 * slot_loss  # negative log-likelihood
            else:
                slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
                # Only keep active parts of the loss
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                    active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
                    active_labels = slot_labels_ids.view(-1)[active_loss]
                    slot_loss = slot_loss_fct(active_logits, active_labels)
                else:
                    slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
            total_loss += (1 - self.args.intent_loss_coef) * slot_loss

        outputs = ((intent_logits, slot_logits),) + outputs[2:]  # add hidden states and attention if they are here
        outputs = (total_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions); logits is a tuple of (intent_logits, slot_logits)
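

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the original training code).
# It assumes the IntentClassifier/SlotClassifier in .module behave as used in
# forward() above; the label lists, hyperparameter values, and the
# SimpleNamespace standing in for the repo's argparse namespace are all
# placeholder assumptions.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoConfig, AutoTokenizer

    intent_labels = ["greeting", "request"]      # placeholder intent label set
    slot_labels = ["PAD", "UNK", "O", "B-date"]  # placeholder slot label set

    args = SimpleNamespace(
        dropout_rate=0.1,
        use_crf=False,                    # skip the CRF layer in this sketch
        use_attention_mask=True,
        embedding_type="soft",            # feed raw intent logits to the slot head
        use_intent_context_concat=False,
        use_intent_context_attention=False,
        max_seq_len=50,
        attention_embedding_size=200,
        intent_loss_coef=0.5,
        ignore_index=-100,
    )

    config = AutoConfig.from_pretrained("xlm-roberta-base")
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")
    # Randomly initialized encoder; the repo would load weights via from_pretrained().
    model = JointXLMR(config, args, intent_labels, slot_labels)

    enc = tokenizer(
        "book a table for tomorrow",
        padding="max_length",
        max_length=args.max_seq_len,
        return_tensors="pt",
    )
    intent_label_ids = torch.tensor([1])                                      # "request"
    slot_labels_ids = torch.full((1, args.max_seq_len), 2, dtype=torch.long)  # all "O"

    loss, (intent_logits, slot_logits) = model(
        input_ids=enc["input_ids"],
        attention_mask=enc["attention_mask"],
        token_type_ids=torch.zeros_like(enc["input_ids"]),
        intent_label_ids=intent_label_ids,
        slot_labels_ids=slot_labels_ids,
    )[:2]
    print(loss.item(), intent_logits.shape, slot_logits.shape)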