# Page residue from the hosting UI ("Spaces: Running") - not part of the program.
import torch | |
from transformers import AutoModel, AutoTokenizer | |
from underthesea import word_tokenize | |
import __main__ | |
# Loading of the base PhoBERT encoder is disabled; the original call is kept
# here for reference only:
#   phobert = AutoModel.from_pretrained("vinai/phobert-base")

# Tokenizer files are read from a local directory relative to the working dir.
_TOKENIZER_DIR = "./bert/bert_tokenizer"
tokenizer = AutoTokenizer.from_pretrained(_TOKENIZER_DIR)
class PhoBertModel(torch.nn.Module):
    """Encoder followed by a small six-output classification head.

    The head is ``Linear -> tanh -> Dropout(0.1) -> Linear(6) -> sigmoid`` and
    is applied to the second element of the encoder's tuple output, so each
    of the six outputs is an independent score in ``(0, 1)``.

    NOTE(review): the attribute names ``bert``, ``pre_classifier``,
    ``dropout`` and ``classifier`` are presumably stored by name in pickled
    checkpoints of this class - keep them stable.
    """

    def __init__(self, bert=None):
        """Build the classification head around ``bert``.

        Args:
            bert: Encoder exposing ``config.hidden_size`` and returning a
                ``(hidden_state, pooled_output)`` tuple when called with
                ``input_ids``, ``attention_mask`` and ``return_dict=False``.
                When omitted, ``vinai/phobert-base`` is loaded with
                ``AutoModel.from_pretrained``.
        """
        super().__init__()
        if bert is None:
            # Previously this read a module-level ``phobert`` that is never
            # defined (its assignment is commented out), so constructing the
            # class raised NameError. Load the same checkpoint on demand.
            from transformers import AutoModel

            bert = AutoModel.from_pretrained("vinai/phobert-base")
        self.bert = bert
        hidden_size = self.bert.config.hidden_size
        self.pre_classifier = torch.nn.Linear(hidden_size, hidden_size)
        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(hidden_size, 6)

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return sigmoid scores for a batch of encoded inputs.

        Args:
            input_ids: Token id tensor forwarded to the encoder.
            attention_mask: Attention mask tensor forwarded to the encoder.
            token_type_ids: Accepted for call compatibility; it is not passed
                to the encoder.

        Returns:
            Tensor of sigmoid-activated outputs with six values in the last
            dimension.
        """
        # With return_dict=False the encoder yields a tuple; only the second
        # element (assumed to be the pooled output - TODO confirm) is used.
        _, pooled = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=False,
        )
        projected = torch.tanh(self.pre_classifier(pooled))
        logits = self.classifier(self.dropout(projected))
        return torch.sigmoid(logits)
# Expose the class on the ``__main__`` module under the name ``PhoBertModel``.
# NOTE(review): presumably required so a checkpoint pickled from a script
# (where the class lived in ``__main__``) can be unpickled here - confirm.
__main__.PhoBertModel = PhoBertModel
def getModel(path='./bert/phoBertModel.pth'):
    """Load a pickled model onto the CPU and switch it to evaluation mode.

    Args:
        path: Checkpoint file to read. Defaults to ``./bert/phoBertModel.pth``.

    Returns:
        The object stored in the checkpoint, after ``.eval()`` was called on it.
    """
    cpu = torch.device('cpu')
    # SECURITY: the checkpoint holds a whole pickled object, so loading it
    # can execute arbitrary code. Only load files from a trusted source.
    try:
        # PyTorch >= 2.6 defaults to weights_only=True, which rejects a
        # checkpoint containing a full pickled module; opt out explicitly.
        model = torch.load(path, map_location=cpu, weights_only=False)
    except TypeError:
        # Older PyTorch releases do not accept the ``weights_only`` keyword.
        model = torch.load(path, map_location=cpu)
    model.eval()
    return model
# Loaded once at import time and shared by every prediction call.
model: torch.nn.Module = getModel()
def tokenize(data, max_length=200):
    """Encode every line in ``data`` into fixed-length model inputs.

    Each line is passed to the module-level ``tokenizer`` without special
    tokens, padded to ``max_length`` and truncated to it, so all rows have
    the same length and can be stacked into one batch.

    Args:
        data: Iterable of inputs accepted by ``tokenizer.encode_plus``.
        max_length: Number of token positions per row. Defaults to 200.

    Returns:
        dict with keys ``'ids'``, ``'mask'`` and ``'token_type_ids'``; each
        value is a tensor with one row per input line, in input order.

    Raises:
        ValueError: If ``data`` yields no lines.
    """
    ids, masks, type_ids = [], [], []
    for line in data:
        encoded = tokenizer.encode_plus(
            line,
            max_length=max_length,
            add_special_tokens=False,
            # ``pad_to_max_length=True`` is deprecated in favour of this.
            padding='max_length',
            # Padding alone never shortens a row; truncate explicitly so an
            # over-long line cannot produce a longer row than max_length.
            truncation=True,
        )
        ids.append(encoded['input_ids'])
        masks.append(encoded['attention_mask'])
        type_ids.append(encoded['token_type_ids'])

    if not ids:
        # Previously an empty input fell through to an unbound local.
        raise ValueError("tokenize() requires at least one line of input")

    return {
        'ids': torch.tensor(ids),
        'mask': torch.tensor(masks),
        'token_type_ids': torch.tensor(type_ids),
    }
def BERT_predict(text):
    """Score one raw text string with the module-level ``model``.

    Args:
        text: Raw input string. It is segmented with
            ``underthesea.word_tokenize`` before being encoded.

    Returns:
        list: The first row of the model output converted to plain Python
        numbers.
    """
    words = word_tokenize(text)
    # ``tokenize`` expects a collection of lines; wrap the single input.
    encoded = tokenize([words])
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        result = model(
            encoded['ids'],
            encoded['mask'],
            encoded['token_type_ids'],
        )
    print(result)
    return result.tolist()[0]
# Import-time smoke test: run one prediction and print the returned scores.
_sample_scores = BERT_predict("xin chaof")
print(_sample_scores)