from typing import Any

from torch.utils.data import DataLoader
from transformers import DataCollatorWithPadding


# some utils for training
class BooksBatcherIter:
    """Draws one chunk per batch slot from a shared stream of books and
    collates the chunks into a padded batch."""

    def __init__(self, data_iter, batch_size, tokenizer, chunk_size=1024):
        self.data_iter = data_iter
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        # One chunk generator per batch slot; the generators share data_iter,
        # so each slot works its way through a different book.
        self.batch_fns = [self._batch_fn() for _ in range(batch_size)]
        self.collate_fn = DataCollatorWithPadding(tokenizer)

    def _batch_fn(self):
        # Assumes each book is a flat sequence of token ids; split it into
        # fixed-size chunks and wrap each chunk so the collator can pad it.
        for book in self.data_iter:
            for i in range(0, len(book), self.chunk_size):
                yield {"input_ids": book[i:i + self.chunk_size]}

    def __iter__(self) -> 'BooksBatcherIter':
        return self

    def __next__(self) -> Any:
        # Pull the next chunk from every slot; when any stream runs out,
        # StopIteration propagates and the trailing partial batch is dropped.
        batch = [next(b) for b in self.batch_fns]
        return self.collate_fn(batch)


class BooksBatcher:
    def __init__(self, dataset, batch_size, tokenizer) -> None:
        self.batch_size = batch_size
        self.tokenizer = tokenizer
        self.dataloader = DataLoader(
            dataset=dataset,
            batch_size=None,  # disable automatic batching: yield raw books one at a time
            shuffle=True,
            num_workers=2,
            prefetch_factor=4,
        )

    def __iter__(self) -> 'BooksBatcherIter':
        return BooksBatcherIter(iter(self.dataloader), self.batch_size, self.tokenizer)
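
# A minimal usage sketch, assuming GPT-2's tokenizer and a toy in-memory list of
# pre-tokenized "books" (both are illustrative choices, not part of the utils
# above); any map-style dataset of token-id sequences works. Each yielded batch
# is a dict of padded tensors keyed by "input_ids", "attention_mask", etc.
if __name__ == "__main__":
    from transformers import GPT2TokenizerFast

    tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token

    # Toy "books": lists of token ids of varying length (hypothetical data).
    books = [
        tokenizer.encode("First tiny book. " * 50),
        tokenizer.encode("Second tiny book, a bit longer. " * 80),
    ]

    batcher = BooksBatcher(books, batch_size=2, tokenizer=tokenizer)
    for batch in batcher:
        print(batch["input_ids"].shape)  # (batch_size, padded_chunk_length)
        break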