XS-dev
按要求创建了文件夹,但是不知道tmd到底是什么文件在这个文件夹里面
a5c42f2
from torch.utils.data import Dataset
import torch
import pandas as pd
def load_data(args, split):
    """Read one split's CSV file and return its texts and labels.

    The file is looked up at ``<args.data_root>/<split>.csv`` and must
    contain a ``text`` column and a ``target`` column.

    Args:
        args: Object exposing a ``data_root`` attribute (the directory that
            holds the split CSV files).
        split: Name of the split; used as the CSV file stem.

    Returns:
        A ``(texts, labels)`` tuple of two plain Python lists, taken from the
        ``text`` and ``target`` columns respectively, in row order.
    """
    csv_path = f"{args.data_root}/{split}.csv"
    frame = pd.read_csv(csv_path)
    # Pull both columns the same way: column -> underlying array -> list.
    texts, labels = (
        frame[column].values.tolist() for column in ("text", "target")
    )
    return texts, labels
class MyDataset(Dataset):
    """Dataset that tokenizes one text per item and optionally attaches its label.

    Each item is a dict with ``"input_ids"`` and ``"attention_mask"``; when the
    dataset is not in test mode it also carries ``"labels"``.
    """

    def __init__(self, data, tokenizer, max_length, is_test):
        """Store the raw data and tokenization settings.

        Args:
            data: Indexable pair; ``data[0]`` is the sequence of texts and
                ``data[1]`` the sequence of labels (same order).
            tokenizer: Object exposing ``batch_encode_plus`` — presumably a
                Hugging Face tokenizer; confirm against the caller.
            max_length: Length every encoded text is padded/truncated to.
            is_test: When true, items are returned without a ``"labels"`` entry.
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = data[0]
        self.labels = data[1]
        self.is_test = is_test

    def __len__(self):
        """Return the number of texts in the dataset."""
        return len(self.texts)

    def __getitem__(self, index):
        """Tokenize the text at ``index`` and return it as a sample dict.

        Returns:
            A dict with ``"input_ids"`` and ``"attention_mask"`` tensors (batch
            axis removed) and, unless ``is_test`` is set, a ``"labels"`` tensor
            built from the label at the same index.
        """
        text = str(self.texts[index])
        # ``padding="max_length"`` already requests fixed-length padding, so the
        # deprecated, redundant ``pad_to_max_length`` flag is not passed.
        source = self.tokenizer.batch_encode_plus(
            [text],
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        # Drop only the leading batch axis (a single text was encoded). A bare
        # ``squeeze()`` would also remove the sequence axis when it has size 1,
        # yielding a 0-d tensor instead of a length-1 sequence.
        data_sample = {
            "input_ids": source["input_ids"].squeeze(0),
            "attention_mask": source["attention_mask"].squeeze(0),
        }
        if not self.is_test:
            label = self.labels[index]
            data_sample["labels"] = torch.tensor(label).squeeze()
        return data_sample