Jaren's picture
Update README.md
1e3a833
|
raw
history blame
1.64 kB

This model uses the `hfl/chinese-roberta-wwm-ext-large` backbone and was trained on the Chinese versions of the SNLI, MNLI, DNLI, and KvPI datasets. The model structure is as follows:

`class RobertaForSequenceClassification(nn.Module): def init(self, tagset_size): super(RobertaForSequenceClassification, self).init() self.tagset_size = tagset_size

    self.roberta_single= AutoModel.from_pretrained(pretrain_model_dir)
    self.single_hidden2tag = RobertaClassificationHead(bert_hidden_dim, tagset_size)

def forward(self, input_ids, input_mask):
    outputs_single = self.roberta_single(input_ids, input_mask, None)
    hidden_states_single = outputs_single[1]#torch.tanh(self.hidden_layer_2(torch.tanh(self.hidden_layer_1(outputs_single[1])))) #(batch, hidden)

    score_single = self.single_hidden2tag(hidden_states_single) #(batch, tag_set)
    return score_single

class RobertaClassificationHead(nn.Module): def init(self, bert_hidden_dim, num_labels): super(RobertaClassificationHead, self).init() self.dense = nn.Linear(bert_hidden_dim, bert_hidden_dim) self.dropout = nn.Dropout(0.1) self.out_proj = nn.Linear(bert_hidden_dim, num_labels)

def forward(self, features):
    x = features#[:, 0, :]  # take <s> token (equiv. to [CLS])
    x = self.dropout(x)
    x = self.dense(x)
    x = torch.tanh(x)
    x = self.dropout(x)
    x = self.out_proj(x)
    return x

model = RobertaForSequenceClassification(num_labels) model.load_state_dict(torch.load(args.model_save_path+'Roberta_large_model.pt', map_location=device))`