Spaces:
Sleeping
Sleeping
from torch.utils.data import Dataset | |
import torch | |
import pandas as pd | |
def load_data(args, split, text_col="text", label_col="target"):
    """Read ``{args.data_root}/{split}.csv`` and return its texts and labels.

    Args:
        args: Any object with a ``data_root`` attribute naming the directory
            that holds the split CSV files.
        split: Split name; selects the file ``<split>.csv``.
        text_col: Name of the column holding the input texts.
        label_col: Name of the column holding the targets.

    Returns:
        A ``(texts, labels)`` pair of plain Python lists in file row order.

    Raises:
        KeyError: If ``text_col`` or ``label_col`` is not a column of the CSV.
    """
    df = pd.read_csv(f"{args.data_root}/{split}.csv")
    # read_csv turns empty cells into NaN; downstream code calls str() on each
    # text, which would otherwise feed the literal word "nan" to the tokenizer.
    texts = df[text_col].fillna("").tolist()
    labels = df[label_col].tolist()
    return texts, labels
class MyDataset(Dataset):
    """Map-style dataset that tokenizes raw texts on access.

    Each item is a dict holding ``input_ids`` and ``attention_mask`` tensors
    (the tokenizer output with its leading batch axis removed) and, unless
    ``is_test`` is set, a ``labels`` tensor built from the matching label.

    Args:
        data: A ``(texts, labels)`` pair of indexable sequences; ``labels`` is
            only read when ``is_test`` is False.
        tokenizer: Object exposing ``batch_encode_plus`` with the keyword
            arguments used below (presumably a Hugging Face tokenizer --
            confirm against callers).
        max_length: Length the tokenizer truncates/pads each text to.
        is_test: When True, items carry no ``labels`` entry.
    """

    def __init__(self, data, tokenizer, max_length, is_test):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts, self.labels = data[0], data[1]
        self.is_test = is_test

    def __len__(self):
        """Return the number of samples (the number of texts)."""
        return len(self.texts)

    def __getitem__(self, index):
        """Return input ids, attention mask and (unless testing) labels for ``index``."""
        text = str(self.texts[index])
        # `padding="max_length"` already requests fixed-length padding; the
        # deprecated `pad_to_max_length` flag was redundant and is dropped.
        source = self.tokenizer.batch_encode_plus(
            [text],
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Drop only the batch axis of size 1 added by encoding a one-item
        # batch; a bare squeeze() would also collapse the sequence axis when
        # max_length == 1 and yield 0-d tensors.
        data_sample = {
            "input_ids": source["input_ids"].squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
        }
        if not self.is_test:
            data_sample["labels"] = torch.tensor(self.labels[index]).squeeze()
        return data_sample