import torch
import random
import time


class Bucketeer:
    """Groups dataloader samples into buckets by spatial resolution so that
    every emitted batch contains videos of a single (width, height) size."""

    def __init__(
        self,
        dataloader,
        sizes=[(256, 256), (192, 384), (192, 320), (384, 192), (320, 192)],
        is_infinite=True,
        epoch=0,
    ):
        # Bucket sizes are (width, height) pairs
        self.sizes = sizes
        self.batch_size = dataloader.batch_size
        self._dataloader = dataloader
        self.iterator = iter(dataloader)
        self.sampler = dataloader.sampler
        self.buckets = {s: [] for s in self.sizes}
        self.is_infinite = is_infinite
        self._epoch = epoch

    def get_available_batch(self):
        # Collect every bucket that has accumulated a full batch
        available_size = [
            b for b in self.buckets if len(self.buckets[b]) >= self.batch_size
        ]
        if len(available_size) == 0:
            return None
        # Pick one full bucket at random and pop a batch from its front
        b = random.choice(available_size)
        batch = self.buckets[b][:self.batch_size]
        self.buckets[b] = self.buckets[b][self.batch_size:]
        return batch

    def __next__(self):
        batch = self.get_available_batch()
        while batch is None:
            try:
                elements = next(self.iterator)
            except StopIteration:
                if self.is_infinite:
                    # Restart the underlying dataloader to make iteration infinite
                    self._epoch += 1
                    if hasattr(self._dataloader.sampler, "set_epoch"):
                        self._dataloader.sampler.set_epoch(self._epoch)
                    time.sleep(2)  # Prevent possible deadlock during epoch transition
                    self.iterator = iter(self._dataloader)
                    elements = next(self.iterator)
                else:
                    raise StopIteration
            for dct in elements:
                try:
                    img = dct['video']
                    size = (img.shape[-1], img.shape[-2])  # (width, height)
                    self.buckets[size].append(
                        {'video': img, **{k: v for k, v in dct.items() if k != 'video'}}
                    )
                except Exception:
                    # Samples whose resolution matches no bucket are dropped
                    continue
            batch = self.get_available_batch()
        out = {k: [batch[i][k] for i in range(len(batch))] for k in batch[0]}
        return {
            k: torch.stack(o, dim=0) if isinstance(o[0], torch.Tensor) else o
            for k, o in out.items()
        }

    def __iter__(self):
        return self

    def __len__(self):
        return len(self.iterator)


class TemporalLengthBucketeer:
    """Groups video latents into buckets by temporal length (number of frames)
    so that every emitted batch shares one frame count. Iteration is always
    infinite: the underlying dataloader is restarted on exhaustion."""

    def __init__(
        self,
        dataloader,
        max_frames=16,
        epoch=0,
    ):
        self.batch_size = dataloader.batch_size
        self._dataloader = dataloader
        self.iterator = iter(dataloader)
        # One bucket per temporal length, from 1 to max_frames
        self.buckets = {temp: [] for temp in range(1, max_frames + 1)}
        self._epoch = epoch

    def get_available_batch(self):
        available_size = [
            b for b in self.buckets if len(self.buckets[b]) >= self.batch_size
        ]
        if len(available_size) == 0:
            return None
        b = random.choice(available_size)
        batch = self.buckets[b][:self.batch_size]
        self.buckets[b] = self.buckets[b][self.batch_size:]
        return batch

    def __next__(self):
        batch = self.get_available_batch()
        while batch is None:
            try:
                elements = next(self.iterator)
            except StopIteration:
                # Restart the underlying dataloader to make iteration infinite
                self._epoch += 1
                if hasattr(self._dataloader.sampler, "set_epoch"):
                    self._dataloader.sampler.set_epoch(self._epoch)
                time.sleep(2)  # Prevent possible deadlock during epoch transition
                self.iterator = iter(self._dataloader)
                elements = next(self.iterator)
            for dct in elements:
                try:
                    video_latent = dct['video']
                    temp = video_latent.shape[2]  # temporal length (frame count)
                    self.buckets[temp].append(
                        {'video': video_latent, **{k: v for k, v in dct.items() if k != 'video'}}
                    )
                except Exception:
                    # Latents whose frame count matches no bucket are dropped
                    continue
            batch = self.get_available_batch()
        out = {k: [batch[i][k] for i in range(len(batch))] for k in batch[0]}
        out = {
            k: torch.cat(o, dim=0) if isinstance(o[0], torch.Tensor) else o
            for k, o in out.items()
        }
        if 'prompt_embed' in out:
            # Repack the pre-extracted textual features under a single 'text' key
            prompt_embeds = out.pop('prompt_embed').clone()
            prompt_attention_mask = out.pop('prompt_attention_mask').clone()
            pooled_prompt_embeds = out.pop('pooled_prompt_embed').clone()
            out['text'] = {
                'prompt_embeds': prompt_embeds,
                'prompt_attention_mask': prompt_attention_mask,
                'pooled_prompt_embeds': pooled_prompt_embeds,
            }
        return out

    def __iter__(self):
        return self

    def __len__(self):
        return len(self.iterator)
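

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the original module). It assumes
# a map-style dataset whose items are dicts with a 'video' tensor, and a
# DataLoader with an identity collate_fn, since both bucketeers iterate over
# each dataloader batch as a list of dicts. The Toy* datasets are hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from torch.utils.data import DataLoader, Dataset

    class ToyVideoDataset(Dataset):
        """Hypothetical dataset that cycles through the bucket resolutions."""

        def __init__(self, sizes, length=64):
            self.sizes = sizes
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            w, h = self.sizes[idx % len(self.sizes)]
            # (C, T, H, W) layout: Bucketeer keys buckets on
            # (shape[-1], shape[-2]), i.e. (width, height).
            return {'video': torch.randn(3, 8, h, w), 'idx': idx}

    sizes = [(256, 256), (192, 384), (384, 192)]
    loader = DataLoader(
        ToyVideoDataset(sizes),
        batch_size=4,
        collate_fn=lambda batch: batch,  # keep each batch as a list of dicts
    )
    bucketed = Bucketeer(loader, sizes=sizes, is_infinite=False)
    batch = next(bucketed)
    # All videos in a bucketed batch share one resolution, so stacking succeeds.
    print(batch['video'].shape)  # e.g. torch.Size([4, 3, 8, 256, 256])

    class ToyLatentDataset(Dataset):
        """Hypothetical dataset of video latents with varying frame counts."""

        def __init__(self, max_frames=4, length=64):
            self.max_frames = max_frames
            self.length = length

        def __len__(self):
            return self.length

        def __getitem__(self, idx):
            t = (idx % self.max_frames) + 1
            # (B=1, C, T, H, W): TemporalLengthBucketeer buckets on shape[2]
            # and concatenates along dim 0, so each item carries a batch axis.
            return {'video': torch.randn(1, 4, t, 32, 32)}

    latent_loader = DataLoader(
        ToyLatentDataset(),
        batch_size=4,
        collate_fn=lambda batch: batch,
    )
    temporal = TemporalLengthBucketeer(latent_loader, max_frames=4)
    latent_batch = next(temporal)
    print(latent_batch['video'].shape)  # e.g. torch.Size([4, 4, 2, 32, 32])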