dnnsdunca committed on
Commit
a5dd61d
1 Parent(s): ac8bb9b

Create Dataset.py

Browse files
Files changed (1) hide show
  1. Dataset.py +31 -0
Dataset.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ from transformers import AutoTokenizer
3
+
4
class MyDataset:
    """Map-style dataset that tokenizes one CSV row per lookup.

    Columns are read by position, not by name: column 0 is the text passed
    to the tokenizer, column 1 is returned as ``labels_agents`` and column 2
    as ``labels_actions``.
    """

    def __init__(self, data_file, tokenizer, max_length=512):
        """Load the CSV and keep the tokenizer for per-item encoding.

        Args:
            data_file: Path (or anything ``pandas.read_csv`` accepts) of the
                CSV file to load.
            tokenizer: Object exposing ``encode_plus``; presumably a Hugging
                Face tokenizer -- verify against the caller.
            max_length: Length every encoded sequence is padded/truncated
                to. Defaults to 512, the previously hard-coded value.
        """
        self.data = pd.read_csv(data_file)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        """Return the number of rows in the loaded CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Encode row ``idx`` and return it together with its two labels.

        Args:
            idx: Positional row index into the loaded data.

        Returns:
            A dict with keys ``input_ids``, ``attention_mask``,
            ``labels_agents`` and ``labels_actions``.
        """
        # Imported here because the module never imports torch at the top
        # level; without this, the torch.tensor calls below raise NameError.
        import torch

        row = self.data.iloc[idx]
        text = row.iloc[0]
        agents = row.iloc[1]
        actions = row.iloc[2]

        encoding = self.tokenizer.encode_plus(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt',
        )

        return {
            # flatten() collapses the tensors returned by the tokenizer to
            # 1-D so each item is a flat sequence.
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            # NOTE(review): assumes both label columns are numeric;
            # torch.tensor cannot convert string labels -- confirm schema.
            'labels_agents': torch.tensor(agents),
            'labels_actions': torch.tensor(actions),
        }