|
import torch |
|
from transformers import AutoModel, AutoTokenizer |
|
from underthesea import word_tokenize |
|
import __main__ |
|
|
|
|
|
|
|
# Location of the saved tokenizer files, resolved relative to the current
# working directory.
_TOKENIZER_DIR = "./bert/bert_tokenizer"
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_DIR)
|
|
|
class PhoBertModel(torch.nn.Module):
    """Encoder wrapped with a small classification head.

    Head: pooled encoder output -> Linear(hidden, hidden) -> tanh ->
    Dropout(0.1) -> Linear(hidden, num_labels) -> sigmoid. The sigmoid is
    applied per output, so every column is an independent score in (0, 1).

    NOTE: a model restored with ``torch.load`` is unpickled without calling
    ``__init__``. ``forward`` therefore only relies on the attributes
    ``bert``, ``pre_classifier``, ``dropout`` and ``classifier``, which such
    a checkpoint restores.
    """

    def __init__(self, bert=None, num_labels=6):
        """Build the model.

        Args:
            bert: Encoder to wrap. It must expose ``config.hidden_size`` and
                accept ``input_ids``, ``attention_mask`` and
                ``return_dict=False``, returning a tuple whose element [1]
                is the pooled output. When omitted, a module-level
                ``phobert`` object is used (as in the original code). That
                name is not defined in this file, so pass ``bert``
                explicitly.
            num_labels: Number of sigmoid outputs (default 6, the value
                previously hard-coded).
        """
        super().__init__()
        if bert is None:
            # Original behaviour: fall back to a global encoder; raises
            # NameError unless ``phobert`` is provided elsewhere.
            bert = phobert  # noqa: F821
        self.bert = bert
        hidden_size = self.bert.config.hidden_size
        self.pre_classifier = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return per-label sigmoid scores of shape (batch, num_labels).

        ``token_type_ids`` is accepted for call compatibility but is not
        passed to the encoder (unchanged from the original behaviour).
        """
        outputs = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        # Index instead of 2-tuple unpacking: with return_dict=False the
        # encoder may append extra outputs (hidden states / attentions).
        pooled = outputs[1]
        hidden = torch.tanh(self.pre_classifier(pooled))
        logits = self.classifier(self.dropout(hidden))
        return torch.sigmoid(logits)
|
|
|
# Register the class as ``__main__.PhoBertModel``. ``getModel`` unpickles a
# whole model object, and pickle looks classes up by module path.
# Presumably the checkpoint was saved from a script where the class lived in
# ``__main__`` — confirm against the training code.
__main__.PhoBertModel = PhoBertModel
|
|
|
def getModel(path='./bert/phoBertModel.pth', map_location='cpu'):
    """Load the pickled classifier and switch it to inference mode.

    Args:
        path: Checkpoint file containing a whole pickled ``nn.Module``
            (default is the previously hard-coded path).
        map_location: Device the tensors are mapped to (default CPU).

    Returns:
        The loaded module, with ``eval()`` already applied.
    """
    import inspect

    load_kwargs = {'map_location': map_location}
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # full nn.Module. Ask for full unpickling explicitly where the keyword
    # exists; torch < 1.13 has no such keyword.
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    # SECURITY: full unpickling can execute arbitrary code stored in the
    # file. Only load checkpoints from a trusted source.
    loaded = torch.load(path, **load_kwargs)
    loaded.eval()
    return loaded
|
|
|
# Load the classifier once at import time; BERT_predict reads this global.
model = getModel()
|
|
|
def tokenize(data, max_length=200):
    """Encode a batch of texts into fixed-length model input tensors.

    Args:
        data: Iterable of strings to encode. Every element is encoded; the
            original loop silently kept only the last one.
        max_length: Every sequence is padded and truncated to exactly this
            many tokens (default 200, the previously hard-coded value).

    Returns:
        dict with keys ``'ids'``, ``'mask'`` and ``'token_type_ids'``, each a
        tensor of shape ``(len(data), max_length)``.

    Raises:
        ValueError: if ``data`` is empty.
    """
    texts = list(data)
    if not texts:
        raise ValueError("tokenize() requires at least one text")
    # NOTE(review): ``word_tokenize`` (underthesea) is imported by this file
    # but never applied to the text. Confirm the preprocessing here matches
    # what the checkpoint was trained on.
    encoded = tokenizer(
        texts,
        max_length=max_length,
        # Replaces the deprecated ``pad_to_max_length=True``.
        padding='max_length',
        # Make truncation explicit so no sequence exceeds max_length.
        truncation=True,
        add_special_tokens=False,
        # Request token_type_ids explicitly so the key is present regardless
        # of the tokenizer's defaults.
        return_token_type_ids=True,
        return_attention_mask=True,
        return_tensors='pt',
    )
    return {
        'ids': encoded['input_ids'],
        'mask': encoded['attention_mask'],
        'token_type_ids': encoded['token_type_ids'],
    }
|
|
|
def BERT_predict(text):
    """Return the model's per-label sigmoid scores for a single text.

    Args:
        text: The input string.

    Returns:
        A list of floats, one score per model output.
    """
    batch = tokenize([text])
    # Inference only: skip autograd bookkeeping (less memory, faster).
    with torch.no_grad():
        probs = model(batch['ids'], batch['mask'], batch['token_type_ids'])
    return probs[0].tolist()
|
|
|
if __name__ == "__main__":
    # Manual smoke test over a few sample Vietnamese inputs. The guard keeps
    # importing this module from running inference and printing.
    for sample in (
        "xin chaof",
        "con chó",
        "đồ chó",
        "đồ ngu",
        "cái lồn",
        "óc chó",
        "đồ chó đẻ",
        "con đĩ",
    ):
        print(BERT_predict(sample))