import random

import torch
from datasets import load_dataset
from torch.utils.data import Dataset, get_worker_info
from transformers import GPT2Tokenizer


def cycled(itr):
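    # Loop over `itr` forever, restarting it whenever it is exhausted.
    # This assumes `itr` can be iterated more than once (true for
    # Hugging Face streaming datasets, unlike plain generators).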
    while True:
        for itm in itr:
            yield itm


class C4X(Dataset):
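    # Map-style view of the streaming C4 corpus: every __getitem__ call
    # returns a fresh random window of `seq_len` GPT-2 token ids.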

    def __init__(self, seq_len=512, split='train'):
        self.seq = seq_len
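        # Stream C4 rather than downloading it up front; the 'en' config
        # is several hundred gigabytes on disk.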
        self.ds = load_dataset(
            'c4',
            name='en',
            split=split,
            streaming=True,
        )
        self.tok = GPT2Tokenizer.from_pretrained('gpt2')
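        # GPT-2's eos_token_id doubles as the document separator used in
        # _get_full below.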
        self.init = False

    def __len__(self):
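        # The stream is effectively unbounded; report an arbitrarily large
        # length so map-style samplers never exhaust the dataset.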
        return 1_000_000_000

    def _init(self):
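        # Deferred setup: runs on first access, inside the DataLoader
        # worker process rather than in the parent.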
        if self.init:
            return
        wi = get_worker_info()
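        # get_worker_info() returns None when no worker processes are used;
        # fall back to a random seed so single-process runs still work.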
        seed = wi.seed if wi is not None else random.randrange(2**31)
        self.ds = cycled(
            self.ds.shuffle(
                seed=seed,
                buffer_size=10_000,
            )
        )
        self.init = True

    def _get_next(self):
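        # Pull the next document off the stream and tokenize it.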
        self._init()
        obj = next(self.ds)['text']
        tkn = self.tok.encode(obj)
        return tkn

    def _get_full(self):
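        # Pack EOS-separated documents until at least `seq` tokens are
        # buffered, then cut a random window of exactly `seq` tokens.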
        obj = []
        while len(obj) < self.seq:
            obj += self._get_next()
            obj.append(self.tok.eos_token_id)
        s = random.randint(0, len(obj) - self.seq)
        return obj[s:s + self.seq]

    def __getitem__(self, _):
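        # The index is ignored: every call draws a fresh sample from the
        # shuffled stream.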
        return torch.tensor(self._get_full())

    def decode(self, tkns):
        return self.tok.decode(tkns)
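

# A minimal usage sketch, not part of the class above: wrap C4X in a
# DataLoader with worker processes. Batch size and worker count are
# illustrative assumptions, and streaming C4 requires network access.
# Note that each worker shuffles its own copy of the stream (seeded by
# its own wi.seed), so batches can overlap across workers.
if __name__ == '__main__':
    from torch.utils.data import DataLoader

    dset = C4X(seq_len=512)
    loader = DataLoader(dset, batch_size=8, num_workers=2)
    batch = next(iter(loader))       # LongTensor of shape (8, 512)
    print(batch.shape)
    print(dset.decode(batch[0][:32].tolist()))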