from typing import Any

from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding


class BooksBatcherIter:
    """Iterates over batches of fixed-size token chunks. Each batch slot
    streams one book at a time, so consecutive batches carry contiguous
    chunks of the same book at the same slot index."""

    def __init__(self, data_iter, batch_size, tokenizer, chunk_size=1024):
        self.data_iter = data_iter
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        # One chunk generator per batch slot; all slots pull books from the
        # shared data_iter, so no two slots ever read the same book.
        self.batch_fns = [self._batch_fn() for _ in range(batch_size)]
        self.collate_fn = DataCollatorWithPadding(tokenizer)

    def _batch_fn(self):
        # Pull books off the shared iterator and emit them in fixed-size
        # chunks, wrapped as feature dicts because DataCollatorWithPadding
        # expects {"input_ids": ...} features (this assumes each book is a
        # flat sequence of token ids).
        for book in self.data_iter:
            for i in range(0, len(book), self.chunk_size):
                yield {"input_ids": book[i:i + self.chunk_size]}

    def __iter__(self) -> 'BooksBatcherIter':
        return self

    def __next__(self) -> Any:
        # Take one chunk from every slot; a StopIteration from any exhausted
        # slot propagates and ends iteration, dropping the partial batch.
        batch = [next(b) for b in self.batch_fns]
        # Pad the chunks to a common length and stack them into tensors.
        return self.collate_fn(batch)
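
# Toy illustration (hypothetical numbers, not part of the API): with
# batch_size=2, chunk_size=3 and books [1, 1, 1, 1, 1] and [2, 2, 2, 2],
# slot 0 streams the first book and slot 1 the second:
#   batch 1: input_ids [[1, 1, 1], [2, 2, 2]]
#   batch 2: input_ids [[1, 1], [2, pad]]   (padded to the longest chunk)
# Iteration stops as soon as any slot has no book left to start.

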
class BooksBatcher:
    """Iterable that shuffles books with a DataLoader and delegates
    chunking and batching to BooksBatcherIter."""

    def __init__(self, dataset, batch_size, tokenizer, chunk_size=1024) -> None:
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.chunk_size = chunk_size
        # batch_size=None disables the DataLoader's own batching, so it
        # yields one whole book at a time; workers prefetch books ahead.
        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=None,
            shuffle=True,
            num_workers=2,
            prefetch_factor=4,
        )

    def __iter__(self) -> 'BooksBatcherIter':
        return BooksBatcherIter(
            iter(self.dataloader),
            self.batch_size,
            self.tokenizer,
            self.chunk_size,
        )
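

# Usage sketch: a minimal smoke test, assuming books are plain lists of
# token ids. "bert-base-uncased" is just an example checkpoint; any
# tokenizer with a pad token works with DataCollatorWithPadding.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

    # Three synthetic "books" of different lengths (toy data, not real text).
    books = [list(range(5, 5 + n)) for n in (2500, 1800, 3000)]

    batcher = BooksBatcher(books, batch_size=2, tokenizer=tokenizer)
    for batch in batcher:
        # Each batch holds one chunk per slot, padded to a common length:
        # shape (batch_size, <=chunk_size).
        print(batch["input_ids"].shape)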