import torch.nn as nn
from transformers import BertModel


class BertClassifier(nn.Module):
    """BERT model with a feed-forward head for classification tasks."""

    def __init__(self, freeze_bert=False):
        """
        @param freeze_bert (bool): Set `True` to freeze BERT's weights and train
            only the classifier head; set `False` to fine-tune the BERT model.
        """
        super(BertClassifier, self).__init__()
        # Hidden size of BERT, hidden size of our classifier, number of labels
        D_in, H, D_out = 768, 50, 2

        # Instantiate the pre-trained BERT model
        self.bert = BertModel.from_pretrained('aubmindlab/bert-base-arabertv02')

        # Instantiate a one-hidden-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(D_in, H),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(H, D_out)
        )

        # Freeze the BERT model if requested
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """
        Feed input to BERT and the classifier to compute logits.
        @param input_ids (torch.Tensor): an input tensor of token ids with shape
            (batch_size, max_length)
        @param attention_mask (torch.Tensor): a tensor that holds attention mask
            information with shape (batch_size, max_length)
        @return logits (torch.Tensor): an output tensor with shape
            (batch_size, num_labels)
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Extract the last hidden state of the `[CLS]` token (position 0)
        # and feed it to the classifier to compute logits
        last_hidden_state_cls = outputs[0][:, 0, :]
        logits = self.classifier(last_hidden_state_cls)

        return logits
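
# --- Usage sketch (illustrative, not part of the original source) ---
# A minimal example of running a forward pass through the classifier,
# assuming the matching AraBERT tokenizer from Hugging Face. The sample
# sentence and max_length value are hypothetical choices for illustration.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('aubmindlab/bert-base-arabertv02')
model = BertClassifier(freeze_bert=True)
model.eval()

# Tokenize a hypothetical Arabic sample ("example sentence for classification")
encoding = tokenizer(
    ["مثال جملة للتصنيف"],
    padding=True,
    truncation=True,
    max_length=64,
    return_tensors='pt',
)

with torch.no_grad():
    logits = model(encoding['input_ids'], encoding['attention_mask'])

print(logits.shape)  # expected: torch.Size([1, 2]), i.e. (batch_size, num_labels)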