import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin
from torch.utils.data import Dataset, DataLoader
from transformers import AutoModel, AutoTokenizer, BertTokenizerFast as BertTokenizer


class EurovocDataset(Dataset):
    """Single-chunk dataset: each document is truncated to `max_token_len` tokens."""

    def __init__(
        self,
        text: np.ndarray,
        labels: np.ndarray,
        tokenizer: BertTokenizer,
        max_token_len: int = 128
    ):
        self.tokenizer = tokenizer
        self.text = text
        self.labels = labels
        self.max_token_len = max_token_len

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index: int):
        # `text` is expected as a 2-D array of shape (n_samples, 1).
        text = self.text[index][0]
        labels = self.labels[index]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        return dict(
            text=text,
            input_ids=encoding["input_ids"].flatten(),
            attention_mask=encoding["attention_mask"].flatten(),
            labels=torch.FloatTensor(labels)
        )
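

# Usage sketch (hypothetical names): assuming `x` is an (n, 1) array of document
# strings and `y` an (n, n_labels) binary label matrix,
#
#   tokenizer = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")
#   dataset = EurovocDataset(x, y, tokenizer, max_token_len=512)
#   item = dataset[0]
#
# yields a dict whose "input_ids" and "attention_mask" are 1-D tensors of length
# 512 and whose "labels" is a float tensor of multi-label targets.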


class EuroVocLongTextDataset(Dataset):
    """Chunk-level dataset: long documents are split into word chunks and every
    chunk inherits the labels of its source document."""

    @staticmethod
    def __splitter__(text, max_length):
        # Split a document into whitespace-delimited chunks of at most
        # `max_length` words each, re-joined into plain strings.
        words = text.split()
        for i in range(0, len(words), max_length):
            yield " ".join(words[i:i + max_length])

    def __init__(
        self,
        text: np.ndarray,
        labels: np.ndarray,
        tokenizer: BertTokenizer,
        max_token_len: int = 128
    ):
        self.tokenizer = tokenizer
        # `text` is expected here as a 1-D array of document strings.
        self.text = text
        self.labels = labels
        self.max_token_len = max_token_len

        # One entry per chunk; the label vector is duplicated for every chunk
        # coming from the same document.
        self.chunks_and_labels = [
            (c, l)
            for t, l in zip(self.text, self.labels)
            for c in self.__splitter__(t, self.max_token_len)
        ]

        self.encoding = self.tokenizer.batch_encode_plus(
            [c for c, _ in self.chunks_and_labels],
            add_special_tokens=True,
            max_length=self.max_token_len,
            return_token_type_ids=False,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

    def __len__(self):
        return len(self.chunks_and_labels)

    def __getitem__(self, index: int):
        text, labels = self.chunks_and_labels[index]

        return dict(
            text=text,
            input_ids=self.encoding["input_ids"][index].flatten(),
            attention_mask=self.encoding["attention_mask"][index].flatten(),
            labels=torch.FloatTensor(labels)
        )
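

# Chunk-expansion sketch (hypothetical data): a single 1,000-word document with
# labels [1, 0, 1] and max_token_len=128 expands into 8 chunk-level items, each
# carrying the same [1, 0, 1] label vector:
#
#   docs = np.array(["word " * 1000])
#   labels = np.array([[1, 0, 1]])
#   ds = EuroVocLongTextDataset(docs, labels, tokenizer, max_token_len=128)
#   len(ds)  # -> 8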


class EurovocDataModule(pl.LightningDataModule):

    def __init__(self, bert_model_name, x_tr, y_tr, x_test, y_test, batch_size=8, max_token_len=512):
        super().__init__()
        self.batch_size = batch_size
        self.x_tr = x_tr
        self.y_tr = y_tr
        self.x_test = x_test
        self.y_test = y_test
        self.tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
        self.max_token_len = max_token_len

    def setup(self, stage=None):
        self.train_dataset = EurovocDataset(
            self.x_tr,
            self.y_tr,
            self.tokenizer,
            self.max_token_len
        )

        self.test_dataset = EurovocDataset(
            self.x_test,
            self.y_test,
            self.tokenizer,
            self.max_token_len
        )

    def train_dataloader(self):
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=2
        )

    def val_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=2
        )

    def test_dataloader(self):
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            num_workers=2
        )


class EurovocTagger(pl.LightningModule, PyTorchModelHubMixin):
    """Multi-label classifier: a BERT encoder with a single sigmoid classification head."""

    def __init__(self, bert_model_name, n_classes, lr=2e-5, eps=1e-8):
        super().__init__()
        self.bert = AutoModel.from_pretrained(bert_model_name)
        self.dropout = nn.Dropout(p=0.2)
        self.classifier1 = nn.Linear(self.bert.config.hidden_size, n_classes)
        # Sigmoid outputs are paired with BCELoss for multi-label targets.
        self.criterion = nn.BCELoss()
        self.lr = lr
        self.eps = eps

    def forward(self, input_ids, attention_mask, labels=None):
        output = self.bert(input_ids, attention_mask=attention_mask)
        output = self.dropout(output.pooler_output)
        output = self.classifier1(output)
        output = torch.sigmoid(output)
        loss = 0
        if labels is not None:
            loss = self.criterion(output, labels)
        return loss, output

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        return {"loss": loss, "predictions": outputs, "labels": labels}

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("val_loss", loss, prog_bar=True, logger=True)
        return loss

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        loss, outputs = self(input_ids, attention_mask, labels)
        self.log("test_loss", loss, prog_bar=True, logger=True)
        return loss

    def on_train_epoch_end(self, *args, **kwargs):
        # Intentionally a no-op.
        return

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=self.lr, eps=self.eps)
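

# Minimal end-to-end training sketch with synthetic data. The checkpoint name,
# label count, and dummy arrays below are hypothetical placeholders, not part of
# the original training setup.
if __name__ == "__main__":
    BERT_MODEL_NAME = "bert-base-multilingual-cased"  # assumed checkpoint
    N_CLASSES = 10                                    # assumed label count

    # Tiny synthetic split so the sketch is self-contained: (n, 1) arrays of
    # document strings and (n, N_CLASSES) binary label matrices.
    rng = np.random.default_rng(0)
    x_tr = np.array([["sample training document"]] * 16)
    y_tr = rng.integers(0, 2, size=(16, N_CLASSES)).astype(np.float32)
    x_test = np.array([["sample test document"]] * 4)
    y_test = rng.integers(0, 2, size=(4, N_CLASSES)).astype(np.float32)

    dm = EurovocDataModule(BERT_MODEL_NAME, x_tr, y_tr, x_test, y_test,
                           batch_size=4, max_token_len=64)
    model = EurovocTagger(BERT_MODEL_NAME, n_classes=N_CLASSES)

    # fast_dev_run runs a single batch of train/val/test as a smoke test.
    trainer = pl.Trainer(accelerator="auto", fast_dev_run=True)
    trainer.fit(model, datamodule=dm)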