from torch import nn
from transformers import BertForSequenceClassification


class ERCBCM(nn.Module):
    """Thin ``nn.Module`` wrapper around a Hugging Face BERT sequence classifier.

    The wrapped ``BertForSequenceClassification`` is called with ``labels``,
    so a single forward pass yields both the loss and the classifier logits.
    """

    def __init__(self, pretrained_model_name='bert-base-uncased',
                 **from_pretrained_kwargs):
        """Load the pretrained encoder.

        Args:
            pretrained_model_name: Model id or local path passed to
                ``BertForSequenceClassification.from_pretrained``. Defaults to
                ``'bert-base-uncased'``, the checkpoint this class always used.
            **from_pretrained_kwargs: Extra keyword arguments forwarded to
                ``from_pretrained`` (e.g. ``num_labels``). When omitted, the
                label count comes from the checkpoint's config, as before.
        """
        super().__init__()
        print('>>> ERCBCM Init!')
        # NOTE: from_pretrained may download weights on first use.
        self.encoder = BertForSequenceClassification.from_pretrained(
            pretrained_model_name, **from_pretrained_kwargs)

    def forward(self, text, label, attention_mask=None):
        """Run the classifier and return its loss and logits.

        Args:
            text: Input token ids passed positionally as ``input_ids``
                (presumably a LongTensor of shape (batch, seq_len) -- confirm
                against the data pipeline).
            label: Target labels passed as ``labels`` so the encoder computes
                the loss.
            attention_mask: Optional mask marking real (1) vs padding (0)
                tokens. Pass it for padded batches so BERT does not attend to
                padding; ``None`` matches the previous behavior.

        Returns:
            A ``(loss, logits)`` tuple: the loss computed by the encoder and
            the classification logits.
        """
        outputs = self.encoder(text, attention_mask=attention_mask,
                               labels=label)
        # Slicing works for tuple outputs (older transformers) and for
        # ModelOutput objects (newer versions); with labels supplied the first
        # two entries are (loss, logits). The second value was previously
        # named ``text_fea`` but it is the classifier logits, not features.
        loss, logits = outputs[:2]
        return loss, logits