|
import pickle
import re
import time

import joblib
import numpy as np
import pandas as pd
import tensorflow as tf

from typing import Optional, Union, Tuple

from gensim.models import Word2Vec

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import MSELoss, CrossEntropyLoss, BCEWithLogitsLoss

from transformers import (
    AutoModel,
    AutoTokenizer,
    BertForSequenceClassification,
    BertModel,
    BertTokenizer,
    Trainer,
    TrainingArguments,
)
from transformers.modeling_outputs import SequenceClassifierOutput

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
|
|
# Sentiment classes: 0 = positive, 1 = negative, 2 = neutral (see idx_to_label in test_trainer).
NUM_CLASSES = 3
DROP_OUT = 0.3
|
|
|
class SentimentDataset(torch.utils.data.Dataset): |
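    """Thin torch Dataset over tokenizer output.

    `encodings` is the dict returned by a HuggingFace tokenizer (input_ids,
    attention_mask, ...); `labels` is an optional sequence of integer class ids
    aligned with the encoded examples.
    """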
|
def __init__(self, encodings, labels=None): |
|
self.encodings = encodings |
|
self.labels = labels |
|
|
|
def __getitem__(self, idx): |
|
item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()} |
|
        if self.labels is not None:
|
item['labels'] = torch.tensor(self.labels[idx]) |
|
return item |
|
|
|
def __len__(self): |
|
return len(self.encodings["input_ids"]) |
|
|
|
class CustomBertForSequenceClassification(BertForSequenceClassification): |
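    """BertForSequenceClassification variant with a deeper classification head.

    Instead of the stock single linear layer on the pooled output, this model
    projects the pooled [CLS] vector to 384 dimensions with a SELU activation,
    applies dropout, and then maps to `num_labels` logits.
    """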
|
|
|
def __init__(self, config): |
|
super().__init__(config) |
|
self.num_labels = config.num_labels |
|
self.config = config |
|
|
|
self.bert = BertModel(config) |
|
classifier_dropout = ( |
|
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob |
|
) |
|
        self.dropout = nn.Dropout(classifier_dropout)

        # Custom classification head: hidden_size -> 384 (SELU) -> num_labels.
        self.linear_h = nn.Linear(config.hidden_size, 384)
        self.linear_o = nn.Linear(384, config.num_labels)
        self.selu = nn.SELU()
|
|
|
print("hidden_size:", config.hidden_size, "num_lables:", config.num_labels) |
|
|
|
|
|
self.post_init() |
|
|
|
def forward( |
|
self, |
|
input_ids: Optional[torch.Tensor] = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
token_type_ids: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.Tensor] = None, |
|
head_mask: Optional[torch.Tensor] = None, |
|
inputs_embeds: Optional[torch.Tensor] = None, |
|
labels: Optional[torch.Tensor] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
) -> Union[Tuple[torch.Tensor], SequenceClassifierOutput]: |
|
r""" |
|
labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*): |
|
Labels for computing the sequence classification/regression loss. Indices should be in `[0, ..., |
|
config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If |
|
`config.num_labels > 1` a classification loss is computed (Cross-Entropy). |
|
""" |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
outputs = self.bert( |
|
input_ids, |
|
attention_mask=attention_mask, |
|
token_type_ids=token_type_ids, |
|
position_ids=position_ids, |
|
head_mask=head_mask, |
|
inputs_embeds=inputs_embeds, |
|
output_attentions=output_attentions, |
|
output_hidden_states=output_hidden_states, |
|
return_dict=return_dict, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
        # outputs[1] is the pooled [CLS] representation from the BERT pooler.
        pooled_output = outputs[1]

        # Custom head: project to 384 dims with SELU, apply dropout, then classify.
        pooled_output = self.selu(self.linear_h(pooled_output))
        pooled_output = self.dropout(pooled_output)

        logits = self.linear_o(pooled_output)
|
|
|
loss = None |
|
if labels is not None: |
|
if self.config.problem_type is None: |
|
if self.num_labels == 1: |
|
self.config.problem_type = "regression" |
|
elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): |
|
self.config.problem_type = "single_label_classification" |
|
else: |
|
self.config.problem_type = "multi_label_classification" |
|
|
|
if self.config.problem_type == "regression": |
|
loss_fct = MSELoss() |
|
if self.num_labels == 1: |
|
loss = loss_fct(logits.squeeze(), labels.squeeze()) |
|
else: |
|
loss = loss_fct(logits, labels) |
|
elif self.config.problem_type == "single_label_classification": |
|
loss_fct = CrossEntropyLoss() |
|
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) |
|
elif self.config.problem_type == "multi_label_classification": |
|
loss_fct = BCEWithLogitsLoss() |
|
loss = loss_fct(logits, labels) |
|
if not return_dict: |
|
output = (logits,) + outputs[2:] |
|
return ((loss,) + output) if loss is not None else output |
|
|
|
return SequenceClassifierOutput( |
|
loss=loss, |
|
logits=logits, |
|
hidden_states=outputs.hidden_states, |
|
attentions=outputs.attentions, |
|
) |
|
|
|
def train_model(model_name, X_train, X_test, y_train, y_test, epochs=2, train_batch_size=8, eval_batch_size=16, use_emotion_x=False): |
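    """Fine-tune a BERT sentiment classifier and report test-set metrics.

    `X_train`/`X_test` are lists of raw sentences and `y_train`/`y_test` are
    integer labels in [0, NUM_CLASSES). When `use_emotion_x` is True the custom
    SELU-head model is used instead of the stock classification head.
    Returns the fitted `Trainer` and the tokenizer.
    """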
|
|
|
tokenizer = BertTokenizer.from_pretrained(model_name) |
|
|
|
train_encodings = tokenizer(X_train, truncation=True, padding=True) |
|
train_dataset = SentimentDataset(train_encodings, y_train) |
|
|
|
test_encodings = tokenizer(X_test, truncation=True, padding=True) |
|
test_dataset = SentimentDataset(test_encodings, y_test) |
|
|
|
print(train_dataset[1]['input_ids'].shape) |
|
print(train_dataset[1]['attention_mask'].shape) |
|
|
|
training_args = TrainingArguments( |
|
output_dir='./results', |
|
num_train_epochs=epochs, |
|
per_device_train_batch_size=train_batch_size, |
|
per_device_eval_batch_size=eval_batch_size, |
|
        warmup_steps=500,

        weight_decay=0.01,
|
logging_dir='./logs', |
|
logging_steps=10, |
|
do_eval=True |
|
) |
|
|
|
    # Fall back to CPU when CUDA is unavailable (the Trainer also handles device placement).
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    if use_emotion_x:
        model = CustomBertForSequenceClassification.from_pretrained(model_name, num_labels=NUM_CLASSES).to(device)
    else:
        model = BertForSequenceClassification.from_pretrained(model_name, num_labels=NUM_CLASSES).to(device)
|
|
|
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
    )
|
|
|
    s = time.time()

    trainer.train()

    print(f"training time: {time.time() - s:.1f} sec")
|
|
|
trainer.evaluate(test_dataset) |
|
|
|
prediction = trainer.predict(test_dataset) |
|
|
|
    # prediction[0] contains the raw logits; convert them to predicted class ids.
    y_logit = torch.tensor(prediction[0])
    y_pred = F.softmax(y_logit, dim=-1).argmax(dim=1).numpy()
|
|
|
print(classification_report(y_test, y_pred)) |
|
print(confusion_matrix(y_test, y_pred)) |
|
print(accuracy_score(y_test, y_pred)) |
|
|
|
    return trainer, tokenizer
|
|
|
|
|
def test_trainer(trainer, tokenizer): |
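    """Run the trained classifier over a small hand-written probe set and print the hit rate."""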
|
POSITIVE = 0 |
|
NEGATIVE = 1 |
|
NEUTRAL = 2 |
|
|
|
idx_to_label = {POSITIVE:'positive', NEGATIVE:'negative', NEUTRAL:'neutral'} |
|
|
|
|
|
test_dict = { |
|
'μ€λ μ§μ¦ μ§λλ‘λ€': NEGATIVE, |
|
        'ν΅μ₯μ΄ ννλΉμμ': NEGATIVE,
|
'κ²½μ μ¬μ μ΄ μ’ λμμ Έμ μ’λ€μ': POSITIVE, |
|
        'κ΅κ°κ° κ΄κ³κ° μνλκ³ μμ΄μ': NEGATIVE,
|
'νκ΅κ³Ό μΌλ³Έμ μ¬μ΄κ° μμ’μμ.': NEGATIVE, |
|
'μ€ν¨λ μ±κ³΅μ μ΄λ¨Έλμ΄λ€.': POSITIVE, |
|
'λ μ¨κ° λ°λ»ν΄μ λ§μμ΄ νΈμν΄μ.': POSITIVE, |
|
'μ£Όλ¨Έλ μ¬μ μ΄ νμ° μ§μ μ' : NEGATIVE, |
|
'λ무 κ±±μ λ§κ³ νλ΄!' : POSITIVE, |
|
'μ μ§μ§! μ§μ¦λκ² κ΅΄μ§λ§κ³ μ 리κ°!' : NEGATIVE, |
|
'μΈμμ΄ νΌκ³€νλ€.' : NEGATIVE, |
|
'λ°λ»ν λ§μ κ°μ¬ν©λλ€.' :POSITIVE, |
|
'λ°λ³΄κ°μ λλ€ νμ¬νλ€' :NEGATIVE, |
|
'κ·Έ λ§μ΄ μ λ₯Ό λ무 νλ€κ² νλ€μ' : NEGATIVE, |
|
'μΈμ§λ§κ³ νλ΄':POSITIVE, |
|
'λλ¬Όμ΄ λ©μΆμ§ μμμ':NEGATIVE, |
|
'μλ‘μ΄ μ¬μ₯λμ μ§μ·¨μ μΈ λΆμ΄λΌ κΈ°λκ° λλ€':POSITIVE, |
|
'μ€λ ν μΌμ΄ νμ°μ΄λ€':NEUTRAL, |
|
'ν μΌμ΄ λ무 λ§μ§λ§ κΎΈμκΎΈμ νκ³ μμ΄':NEUTRAL, |
|
'λ°°κ° κ³ νλ€μ':NEUTRAL, |
|
'μ§μ κ°κ³ μΆλ€μ':NEUTRAL, |
|
'μ½μ½μ νμ νμ€λμ?':NEUTRAL, |
|
'μ»΄ν¨ν° λ°κΏμ£ΌμΈμ.':NEUTRAL, |
|
'νλ λ§μλ?': NEGATIVE, |
|
'μ λλ μ¬νμ μκ°νλ κΈ°λΆμ΄ μ’μ΅λλ€':POSITIVE, |
|
'λ°°κ³ νλ° λ°₯μ΄ μμ΄μ.':NEGATIVE, |
|
'κ΅κ° κ²½μ κ° νν λλ μ€μ΄λ€.':NEGATIVE, |
|
'λλλ¬Έμ λ΄κ° λ무 νλ€μ΄':NEGATIVE, |
|
'κ·Έλλ λκ° μμ΄μ λ€νμ΄μΌ':POSITIVE, |
|
'μμΈν κ²½μ μ¬μ μλ μ΄μ¬ν ν΄μ€μ κ³ λ§μμ':POSITIVE, |
|
'μ€λ κΈ°λΆ μ§±μ΄μμ':POSITIVE, |
|
'λλ λ체 ν μ€ μλκ² λλ?':NEGATIVE, |
|
'μμ κ° λ무 μ΄λ €μ λ―ΈμΉκ² λ€':NEGATIVE, |
|
'μ°λ¦¬ νμλ€ μ΄μ¬ν ν΄μ€μ μλμ€λ½μ΅λλ€':POSITIVE, |
|
'Wow! μν μ§μ§ μ¬λ―Έμλ€':POSITIVE, |
|
        'γγνλ€μ΄ μ£½μκ±° κ°μμ':NEGATIVE,
|
'μ΄λ² μ¬νμ½μ€λ μ λ§ νμμ μ΄λ€μ':POSITIVE, |
|
'λ΅λ΅ν μν©μ΄μ§λ§ λ μ΄κ²¨λΌ μ μμκΊΌμΌ':POSITIVE, |
|
'λ΅λ΅ν μν©μ΄μ§λ§ λ μ ν΄λΌ μ μμκΊΌμΌ':POSITIVE, |
|
'μΈμ λ κ³μ μμ΄μ€μ νμ΄ λ©λλ€.':POSITIVE, |
|
'λͺΈμ΄ λ무 μνμ μΌμ΄ μμ μμ‘νμ':NEGATIVE, |
|
'λ μ λ§ μνλ€ λ¦¬μ€ν!':POSITIVE, |
|
'μ¬νμ§λ§ κ΄μ±¦μ':POSITIVE, |
|
'κ°λΉ‘μΉλ€ μ§μ§':NEGATIVE, |
|
'λΉκ° λ무 λ§μ΄ μμ μ§μ΄ λ λ΄λ €κ°μ΄μ':NEGATIVE, |
|
'νλΉμ΄ μ¨μ¨ν΄μ μ·μ΄ μ λ§λ₯΄λ€μ':POSITIVE, |
|
'AI곡λΆλ μ΄λ ΅μ§λ§ μ¬λ―Έμμ΄μ':POSITIVE, |
|
'λ μ΄μ©λ©΄ μ’λ? νμ¨λ°μ μλμ¨λ€':NEGATIVE, |
|
'λλ체 λ¬΄μ¨ μκ°μΌλ‘ μ΄λ° μ§μ νκ±°μΌ?':NEGATIVE, |
|
'λ―Έμλ λ€μ νλ²':POSITIVE, |
|
'λμ λ§μ κ°μ¬ν©λλ€':POSITIVE, |
|
'λ§λ μλλ μ리 κ·Έλ§νκ³ μ 리κ°':NEGATIVE, |
|
'μ€λ 컀νΌμ± λΆμκΈ° κ΅Ώ':POSITIVE, |
|
'κΈ°λΆ λλΉ μ λλ μκΈ°νκΈ° μ«μ΄':NEGATIVE, |
|
'μ΄ κ·Έλ¦Ό λ무 λ§μμ λ λ€':POSITIVE, |
|
'μ΄μ΄κ° μμ΄μ ν λ§μ΄ μμ΄':NEGATIVE, |
|
'λλ£ μ§μμ΄ ν΄μ¬ μΈμ¬λ₯Ό νλλ° μμΈν λ§μμ΄ λλ€':NEGATIVE, |
|
'νμμ΄ μμ΄λμ΄ κ²ν λ₯Ό μμ²νλλ° λ무 μ’μ μμ΄λμ΄ κ°μ. μ견μ λ¬Όμ΄λ΄μ€μ κ³ λ§μ':POSITIVE, |
|
'μ±κ²©μ΄ μ’μ νμλ€κ³Ό ν¨κ» ν μ μμ΄μ λ€νμ΄μΌ':POSITIVE, |
|
'κΈμμΌλ§ λλ©΄ κΈ°λΆμ΄ μ’μμ Έ':POSITIVE, |
|
'λ²μ¨ μΌμμΌμ΄λΌλ μΆκ·Όν μκ°νλ κΈ λ€μ΄λλ€.':NEGATIVE, |
|
'μ§μ¦λλκΉ μκΈ°νμ§λ§!':NEGATIVE, |
|
'λ무 μ¬μ¬ν΄.':NEUTRAL, |
|
'λλν μ¬λμ΄λ λννλ건 μ¦κ±°μμ':POSITIVE, |
|
'λΉμ μ νμ μλ μΌκ΅΄μ΄μ΄μ λ§λλ©΄ κΈ°λΆμ΄ μ’μμ Έμ':POSITIVE, |
|
'μ°μ€μ΄ λ무 λ°λΆν΄μ ννμ΄ λμμ':NEGATIVE, |
|
'λ§μλ μλΉμ κ° μκ°μ νλ μ λμ':POSITIVE, |
|
        'μ΄λ° νλ₯ν κ°μλ₯Ό λ£κ² λμ μκ΄μλλ€.':POSITIVE,
|
'λ§λλ΅κ² λμ λ°κ°μ΅λλ€.':POSITIVE, |
|
'κ·Έ μ¬λλ§ λ§λλ©΄ μ§μ¦μ΄ λμ λ³΄κΈ°κ° μ«μ΄':NEGATIVE, |
|
        'μμ΄λ€μ΄ νκΈ°μ°¨κ² λ°μ΄λΈλ λͺ¨μ΅μ΄ 보기 μ’μμ':POSITIVE,
|
'νμ¬ν μλ¦¬μ’ κ·Έλ§ν μ μμ΄μ?':NEGATIVE, |
|
'μκΈ°κ³ μλΉ μ‘λ€!':NEGATIVE, |
|
        'ν΄! μλκ°μνλ€!':NEUTRAL,
|
        'λ§κ°μ§λ μμ μ리νκ³ μμ΄! γγ':NEGATIVE,
|
        'μμμ μμ΄ μλμΌλ‘ λμ¨λ€...':NEGATIVE,
|
        'μλ§ μ΄λ©΄ κ±°μ§λ§μ΄ μλμΌλ‘ λμ!':NEGATIVE,
|
'μ κ±° λ°λ³΄ μλ?':NEGATIVE, |
|
'νλ€λ κ³μ μμ΄μ€μ κ³ λ§μ':POSITIVE, |
|
'μμ΄νκ° μμ«μ μ΄ν κ°μ':NEGATIVE, |
|
'μ λ° λͺ¨μ§λ¦¬ κ°μΌλλΌκ³ ':NEGATIVE, |
|
'μ§μ§λ¦¬ λͺ»λ λ':NEGATIVE, |
|
        'μ μΈκ° λλ¬Έμ λ΄κ° μ λͺμ λͺ»μ΄κ² κ°μ':NEGATIVE,
|
'μ μλΌ μ£½μ¬':NEGATIVE, |
|
'λ μ λ§ μ²μ¬κ°μ':POSITIVE, |
|
'λΉμ μ΄ μ’μμ νμ κ³μ μμ΄μ£ΌμΈμ':POSITIVE, |
|
'κΌ΄λ 보기 μ«μΌλ μ© κΊΌμ Έ':NEGATIVE, |
|
'μ μ§μ§ λμλ²λ¦¬κ² λ€':NEGATIVE, |
|
'μκ²¨μ΄ λλ€':NEGATIVE, |
|
'μ λ° λ―ΈμΈμ 보λ μκ΅¬κ° μ νλλ λλμ΄μΌ':POSITIVE, |
|
        'μμ€ γγλ μ©λλ€':NEGATIVE,
|
'κΉμΉμ§λ§ λ€μ§λ?':NEGATIVE, |
|
'μΈμ λ νμμ΄μμ':POSITIVE, |
|
'μ€ ν¨λ²λ¦¬κ³ μΆλ€ μ§μ§':NEGATIVE, |
|
'μ κΈ°λ§ λ³΄λ©΄ μμμ΄ λμ':POSITIVE, |
|
'νλ μ§ λ³΄λ©΄ μ λ₯μ κ°μ':NEGATIVE, |
|
'μΉμ±μΆ':NEGATIVE, |
|
'μλλ»':NEGATIVE, |
|
'μ΄ λΉ‘λκ°λ¦¬μΌ':NEGATIVE, |
|
'λλκ°λ¦¬ μμ':NEGATIVE, |
|
'λλ μμμ΄λΌκ³ λ³μ λ μλ§κ° λΆμ':NEGATIVE, |
|
'μμ 곡주λμ΄μμ μΆνν΄μ':POSITIVE, |
|
'μ©μ©ν μμλμ΄μμ. μ’μΌμκ² μ΄μ.':POSITIVE, |
|
'μΌμ¨κ΅¬ μ’λ€':POSITIVE, |
|
'νλμ΄ λ¬΄λμ§λ κΈ°λΆμ΄μΌ':NEGATIVE, |
|
'νλμ λλ κΈ°λΆμ΄μΌ':POSITIVE, |
|
        'μμμμ νμ΄ν!':POSITIVE,
|
'κ°μλΌ':NEGATIVE, |
|
'μμ£Ό λμ΄μ€':POSITIVE, |
|
'λ΅λ μλ μΈκ°λ€':NEGATIVE, |
|
'μ λ§ μ¬κΈ΄ μ λ₯μ μ§λ¨ κ°μ':NEGATIVE, |
|
'λ§λμ λ°κ°μμ. μ λ§ λ―ΈμΈμ΄μλ€μ':POSITIVE, |
|
'λΉμ μ΄ κ·Έλ¦¬μμ. λ³΄κ³ μΆμ΄μ.':POSITIVE, |
|
'λ°λΌλ§ λ΄λ μμμ΄ λμμ':POSITIVE, |
|
'κ° μ΄λ°λ€':NEGATIVE, |
|
        'λ μ λλ° γγγ':POSITIVE,
|
'λ무 λ³΄κ³ μΆμμ΄μ. μ΄λ κ² λ§λκ²λμ λ°κ°μ΅λλ€.':POSITIVE, |
|
'μΉκ΅¬μΌ μ¬λν΄':POSITIVE, |
|
'μ΄ λ°λ³΄ μμμ':NEGATIVE, |
|
'μ€λμ λ μ¨κ° μ°Έ μ’λ€μ. κΈ°λΆμ΄ μμΎν΄μ.':POSITIVE, |
|
'μμν΄μ λ°₯μ΄ μλμ΄κ°λ€.': NEGATIVE, |
|
'λ§μμ΄ μΈμ ν΄μ κΈΈμ λμ°λ€':NEGATIVE, |
|
'μ€λμ μΈμ μ΅κ³ μ λ ': POSITIVE, |
|
        'μ΄ νλ₯ν μΌμ λμ°Ένκ² λμ μκ΄μλλ€.':POSITIVE,
|
'λ μ§μμ μλ μ€':NEUTRAL, |
|
} |
|
|
|
hit_cnt = 0 |
|
tot_cnt = len(test_dict) |
|
|
|
for x, y in test_dict.items(): |
|
tokenized = tokenizer([x], truncation=True, padding=True) |
|
pred = trainer.predict(SentimentDataset(tokenized)) |
|
|
|
logit = torch.tensor(pred[0]) |
|
result = F.softmax(logit, dim=-1).argmax(1).numpy() |
|
|
|
if result[0] != y: |
|
print(f"ERROR: {x} expected:{idx_to_label[y]} result:{idx_to_label[result[0]]}") |
|
else: |
|
hit_cnt += 1 |
|
|
|
print() |
|
print(f"hit/total: {hit_cnt}/{tot_cnt}, rate: {hit_cnt/tot_cnt}") |