import os
import random
import pickle
import torch

from tencentpretrain.utils.constants import *
from tencentpretrain.utils.tokenizers import *
from tencentpretrain.utils.mask import mask_seq
from tencentpretrain.utils.augment import SpecAugment


class Dataloader(object):
    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        self.tokenizer = args.tokenizer
        self.batch_size = batch_size
        self.instances_buffer_size = args.instances_buffer_size
        self.rank = rank
        self.world_size = world_size
        self.gpu_id = gpu_id
        self.shuffle = shuffle
        self.model_for_dataloader = model_for_dataloader
        self.dataset_reader = open(dataset_path, "rb")
        self.read_count = 0
        self.start = 0
        self.end = 0
        self.buffer = []
        self.vocab = args.vocab
        self.whole_word_masking = args.whole_word_masking
        self.span_masking = args.span_masking
        self.span_geo_prob = args.span_geo_prob
        self.span_max_length = args.span_max_length

    def _fill_buf(self):
        try:
            self.buffer = []
            while True:
                instance = pickle.load(self.dataset_reader)
                self.read_count += 1
                # Shard instances round-robin across distributed workers.
                if (self.read_count - 1) % self.world_size == self.rank:
                    self.buffer.append(instance)
                    if len(self.buffer) >= self.instances_buffer_size:
                        break
        except EOFError:
            # Reached file end: rewind so the next pass starts from the beginning.
            self.dataset_reader.seek(0)

        if self.shuffle:
            random.shuffle(self.buffer)
        self.start = 0
        self.end = len(self.buffer)

    def _empty(self):
        return self.start >= self.end

    def __del__(self):
        self.dataset_reader.close()


class BertDataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt_mlm = []
            is_next = []
            seg = []

            masked_words_num = 0

            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))

                if len(ins) == 4:
                    # Statically masked instance: masking was done at preprocessing time.
                    src.append(src_single)
                    masked_words_num += len(ins[1])
                    tgt_mlm.append([0] * len(src_single))
                    for mask in ins[1]:
                        tgt_mlm[-1][mask[0]] = mask[1]
                    is_next.append(ins[2])
                    seg.append([1] * ins[3][0] + [2] * (ins[3][1] - ins[3][0]) + [0] * pad_num)
                else:
                    # Dynamic masking: mask the sequence on the fly.
                    src_single, tgt_mlm_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    masked_words_num += len(tgt_mlm_single)
                    src.append(src_single)
                    tgt_mlm.append([0] * len(src_single))
                    for mask in tgt_mlm_single:
                        tgt_mlm[-1][mask[0]] = mask[1]
                    is_next.append(ins[1])
                    seg.append([1] * ins[2][0] + [2] * (ins[2][1] - ins[2][0]) + [0] * pad_num)

            if masked_words_num == 0:
                continue

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt_mlm), \
                  torch.LongTensor(is_next), \
                  torch.LongTensor(seg)


class MlmDataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt = []
            seg = []

            masked_words_num = 0

            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))

                if len(ins) == 3:
                    src.append(src_single)
                    masked_words_num += len(ins[1])
                    tgt.append([0] * len(src_single))
                    for mask in ins[1]:
                        tgt[-1][mask[0]] = mask[1]
                    seg.append([1] * ins[2][0] + [0] * pad_num)
                else:
                    src_single, tgt_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    masked_words_num += len(tgt_single)
                    src.append(src_single)
                    tgt.append([0] * len(src_single))
                    for mask in tgt_single:
                        tgt[-1][mask[0]] = mask[1]
                    seg.append([1] * ins[1][0] + [0] * pad_num)

            if masked_words_num == 0:
                continue

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt), \
                  torch.LongTensor(seg)
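# The MLM-style loaders above store targets sparsely as (position, token_id)
# pairs and expand them into a dense per-token vector. A minimal sketch with
# made-up ids (hypothetical values, for illustration only):
#
#   src_single = [101, 7592, 103, 2088, 102]   # token at position 2 is masked
#   tgt_single = [(2, 2003)]                   # mask_seq output: (position, original id)
#   tgt = [0] * len(src_single)
#   for position, token_id in tgt_single:
#       tgt[position] = token_id               # -> [0, 0, 2003, 0, 0]
#
# Positions left at 0 carry no loss, which is why a batch with
# masked_words_num == 0 is skipped instead of yielded.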
class AlbertDataloader(BertDataloader):
    """
    AlbertDataloader can reuse the code of BertDataloader.
    """
    pass


class LmDataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt = []
            seg = []

            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))

                src.append(src_single[:-1])
                tgt.append(src_single[1:])
                seg.append([1] * ins[1][0] + [0] * (len(src_single) - 1 - ins[1][0]))

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt), \
                  torch.LongTensor(seg)
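# LmDataloader builds next-token prediction pairs by shifting the sequence one
# position. A sketch with toy ids (hypothetical values):
#
#   src_single = [5, 6, 7, 8, 0]   # already padded
#   src        = [5, 6, 7, 8]      # src_single[:-1]
#   tgt        = [6, 7, 8, 0]      # src_single[1:]
#
# so the model at position i is trained to predict the token at position i + 1.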
class BilmDataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt_forward = []
            tgt_backward = []
            seg = []

            for ins in instances:
                src_single, pad_num = ins[0]
                tgt_forward_single, tgt_backward_single = ins[1], ins[2]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                    tgt_forward_single.append(self.vocab.get(PAD_TOKEN))
                    tgt_backward_single.append(self.vocab.get(PAD_TOKEN))

                src.append(src_single)
                tgt_forward.append(tgt_forward_single)
                tgt_backward.append(tgt_backward_single)
                seg.append([1] * ins[3][0] + [0] * (len(src_single) - ins[3][0]))

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt_forward), \
                  torch.LongTensor(tgt_backward), \
                  torch.LongTensor(seg)


class MtDataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt_in = []
            tgt_out = []
            seg = []
            tgt_seg = []

            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))

                tgt_single, pad_num = ins[1]
                for _ in range(pad_num):
                    tgt_single.append(self.vocab.get(PAD_TOKEN))

                src.append(src_single)
                tgt_in.append(tgt_single[:-1])
                tgt_out.append(tgt_single[1:])
                seg.append([1] * ins[2][0] + [0] * (len(src_single) - ins[2][0]))
                pad_num = max(ins[1][1] - 1, 0)  # left shifted, pad_num >= 0
                tgt_seg.append([1] * (len(tgt_in[-1]) - pad_num) + [0] * pad_num)

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt_out), \
                  torch.LongTensor(seg), \
                  torch.LongTensor(tgt_in), \
                  torch.LongTensor(tgt_seg)


class T5Dataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt_in = []
            tgt_out = []
            seg = []
            tgt_seg = []
            tgt_seq_length = 0

            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))

                if len(ins) == 3:
                    tgt_single = ins[1]
                    seg.append([1] * ins[2][0] + [0] * pad_num)
                else:
                    src_single, tgt_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    seg.append([1] * ins[1][0] + [0] * pad_num)

                MASK_ID = self.vocab.get(MASK_TOKEN)
                SENTINEL_ID = self.vocab.get(SENTINEL_TOKEN)
                PAD_ID = self.vocab.get(PAD_TOKEN)
                for src_index, _ in tgt_single:
                    if src_single[src_index] != MASK_ID:
                        src_single[src_index] = MASK_ID

                tgt_in_single = [self.vocab.get(CLS_TOKEN)]
                mask_index = 0
                src_with_sentinel = []
                for token_id in src_single:
                    if token_id == MASK_ID:
                        # Consecutive masks share one sentinel on the source side.
                        if len(src_with_sentinel) > 0 and src_with_sentinel[-1] == (SENTINEL_ID - 1):
                            pass
                        else:
                            src_with_sentinel.append(SENTINEL_ID)
                            tgt_in_single.append(SENTINEL_ID)
                            if SENTINEL_ID < len(self.vocab) - 1:
                                SENTINEL_ID += 1
                        tgt_in_single.append(tgt_single[mask_index][1])
                        mask_index += 1
                    else:
                        src_with_sentinel.append(token_id)
                tgt_in_single.append(SENTINEL_ID)
                tgt_in_single.append(self.vocab.get(SEP_TOKEN))

                tgt_seg_single = [1] * len(tgt_in_single)
                while len(src_with_sentinel) < len(src_single):
                    src_with_sentinel.append(PAD_ID)

                if len(tgt_in_single) > tgt_seq_length:
                    tgt_seq_length = len(tgt_in_single)

                src.append(src_with_sentinel)
                tgt_in.append(tgt_in_single)
                tgt_seg.append(tgt_seg_single)
                tgt_out.append(tgt_in[-1][1:] + [PAD_ID])

            # Pad all target sequences in the batch to the same length.
            for i in range(len(tgt_in)):
                while len(tgt_in[i]) != tgt_seq_length:
                    tgt_in[i].append(PAD_ID)
                    tgt_out[i].append(PAD_ID)
                    tgt_seg[i].append(0)

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt_out), \
                  torch.LongTensor(seg), \
                  torch.LongTensor(tgt_in), \
                  torch.LongTensor(tgt_seg)
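# T5Dataloader applies T5-style span corruption: masked spans are replaced by
# consecutive sentinel ids on the source side, and the target interleaves each
# sentinel with the tokens it hides. A toy sketch (M = MASK_ID, S0/S1/S2 stand
# for consecutive sentinel ids; letters are hypothetical tokens):
#
#   src_single        = [a, M, M, d, e, M, g]
#   src_with_sentinel = [a, S0, d, e, S1, g, PAD]        # adjacent masks collapse
#   tgt_in_single     = [CLS, S0, b, c, S1, f, S2, SEP]  # sentinel, then recovered tokens
#
# tgt_out is tgt_in shifted left by one with PAD appended, and every target in
# the batch is padded up to tgt_seq_length.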
class GsgDataloader(MtDataloader):
    pass


class BartDataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt_in = []
            tgt_out = []
            seg = []
            tgt_seg = []

            for ins in instances:
                src_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))

                tgt_single, pad_num = ins[1]
                for _ in range(pad_num):
                    tgt_single.append(self.vocab.get(PAD_TOKEN))

                src_single, _ = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                seg_pos = ins[2][0]

                tgt_in.append(tgt_single[:-1])
                tgt_out.append(tgt_single[1:])
                pad_num = max(ins[1][1] - 1, 0)  # left shifted, pad_num >= 0
                tgt_seg.append([1] * (len(tgt_in[-1]) - pad_num) + [0] * pad_num)

                MASK_ID = self.vocab.get(MASK_TOKEN)
                src_with_span_mask = []
                for token_id in src_single:
                    if token_id == MASK_ID:
                        if len(src_with_span_mask) > 0 and src_with_span_mask[-1] == MASK_ID:
                            seg_pos -= 1
                        else:
                            src_with_span_mask.append(MASK_ID)
                    else:
                        src_with_span_mask.append(token_id)

                while len(src_with_span_mask) < len(src_single):
                    src_with_span_mask.append(self.vocab.get(PAD_TOKEN))

                seg.append([1] * seg_pos + [0] * (len(src_single) - seg_pos))
                src.append(src_with_span_mask)

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt_out), \
                  torch.LongTensor(seg), \
                  torch.LongTensor(tgt_in), \
                  torch.LongTensor(tgt_seg)


class ClsDataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt = []
            seg = []

            for ins in instances:
                src_single, pad_num = ins[0]
                seg_pos_single = ins[2]
                if len(seg_pos_single) == 1:
                    seg_single = [1] * seg_pos_single[0]
                elif len(seg_pos_single) == 2:
                    seg_single = [1] * seg_pos_single[0] + [2] * seg_pos_single[1]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                    seg_single.append(0)

                src.append(src_single)
                tgt.append(ins[1])
                seg.append(seg_single)

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt), \
                  torch.LongTensor(seg)


class PrefixlmDataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt = []
            seg = []

            for ins in instances:
                src_single, pad_num = ins[0]
                tgt_single = ins[1]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                    tgt_single.append(self.vocab.get(PAD_TOKEN))

                src.append(src_single)
                tgt.append(tgt_single)
                seg.append([1] * ins[2][0] + [2] * (ins[2][1] - ins[2][0]) + [0] * (len(src_single) - ins[2][1]))

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt), \
                  torch.LongTensor(seg)


class ClsMlmDataloader(Dataloader):
    def __iter__(self):
        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt_mlm = []
            tgt_cls = []
            seg = []

            masked_words_num = 0

            for ins in instances:
                src_single, pad_num = ins[0]
                seg_pos_single = ins[-1]
                tgt_cls.append(ins[-2])

                if len(seg_pos_single) == 1:
                    seg_single = [1] * seg_pos_single[0]
                elif len(seg_pos_single) == 2:
                    seg_single = [1] * seg_pos_single[0] + [2] * seg_pos_single[1]
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))
                    seg_single.append(0)
                seg.append(seg_single)

                if len(ins) == 4:
                    src.append(src_single)
                    masked_words_num += len(ins[1])
                    tgt_mlm.append([0] * len(src_single))
                    for mask in ins[1]:
                        tgt_mlm[-1][mask[0]] = mask[1]
                else:
                    src_single, tgt_single = mask_seq(src_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                    src.append(src_single)
                    masked_words_num += len(tgt_single)
                    tgt_mlm.append([0] * len(src_single))
                    for mask in tgt_single:
                        tgt_mlm[-1][mask[0]] = mask[1]

            if masked_words_num == 0:
                continue

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt_mlm), \
                  torch.LongTensor(tgt_cls), \
                  torch.LongTensor(seg)


class VisionDataloader(Dataloader):
    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        super(VisionDataloader, self).__init__(args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle, model_for_dataloader)
        self.patch_size = args.patch_size
        self.image_height = args.image_height
        self.image_width = args.image_width

        from torchvision import transforms
        from tencentpretrain.utils.misc import ZeroOneNormalize

        preprocess_pipeline = []
        if "corp" in args.image_preprocess:
            preprocess_pipeline.append(transforms.RandomResizedCrop(max(self.image_height, self.image_width)))
        if "horizontal_flip" in args.image_preprocess:
            preprocess_pipeline.append(transforms.RandomHorizontalFlip())
        preprocess_pipeline.append(transforms.Resize((self.image_height, self.image_width)))
        preprocess_pipeline.append(ZeroOneNormalize())
        if "normalize" in args.image_preprocess:
            preprocess_pipeline.append(transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
                                                            (0.26862954, 0.26130258, 0.27577711)))
        self.transform = transforms.Compose(preprocess_pipeline)
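# For example, with args.image_preprocess = ["corp", "horizontal_flip",
# "normalize"], the composed pipeline above becomes (sketch):
#
#   RandomResizedCrop -> RandomHorizontalFlip -> Resize -> ZeroOneNormalize -> Normalize
#
# Resize and ZeroOneNormalize are always applied; the crop, flip and normalize
# stages are opt-in through args.image_preprocess.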
class VitDataloader(VisionDataloader):
    def __iter__(self):
        """
        instances: (tgt, image_path)
            tgt: The category the image belongs to
            image_path: Path of the image sample

        Returns:
            src_image: [batch_size x channel_size x height x width]
            seg: [batch_size x (patch_num + 1)]
            tgt: [batch_size]
        """
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode

        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt = []
            seg = []

            for ins in instances:
                image = read_image(ins[1], ImageReadMode.RGB)
                image = image.cuda(self.gpu_id)
                src.append(self.transform(image))
                tgt.append(ins[0])
                seg.append([1] * ((self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1))

            yield torch.stack(src, 0), \
                  torch.LongTensor(tgt), \
                  torch.LongTensor(seg)


class ViltDataloader(VisionDataloader):
    def __iter__(self):
        """
        instances: (src_text, seg_text, image_path)
            src_text: Tokens of the text sample
            seg_text: Segment input of the text sample
            image_path: Path of the image sample

        Returns:
            src_text: [batch_size x seq_length]
            src_image: [batch_size x channel_size x height x width]
            tgt_mlm: [batch_size x (seq_length + patch_num + 1)]
            tgt_match: [batch_size]
            seg: [batch_size x (seq_length + patch_num + 1)]
        """
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode

        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src_text = []
            src_image = []
            tgt_mlm = []
            tgt_match = []
            seg = []

            masked_words_num = 0

            for ins in instances:
                src_text_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_text_single.append(self.vocab.get(PAD_TOKEN))

                src_text_single, tgt_mlm_single = mask_seq(src_text_single, self.tokenizer, self.whole_word_masking, self.span_masking, self.span_geo_prob, self.span_max_length)
                src_text.append(src_text_single)
                masked_words_num += len(tgt_mlm_single)
                tgt_mlm.append([0] * len(src_text_single))
                for mask in tgt_mlm_single:
                    tgt_mlm[-1][mask[0]] = mask[1]

                # Image-text matching: keep the paired image half the time,
                # otherwise sample a random image from the buffer as a negative.
                if random.random() < 0.5:
                    image = read_image(ins[2], ImageReadMode.RGB)
                    tgt_match.append(1)
                else:
                    image = read_image(random.choice(self.buffer)[2], ImageReadMode.RGB)
                    tgt_match.append(0)

                seg_image = [2] * ((self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1)
                tgt_mlm[-1].extend([0] * len(seg_image))
                image = image.cuda(self.gpu_id)
                src_image_single = self.transform(image)
                src_image.append(src_image_single)
                seg.append([1] * ins[1][0] + [0] * pad_num + seg_image)

            if masked_words_num == 0:
                continue

            yield torch.LongTensor(src_text), \
                  torch.stack(src_image, 0), \
                  torch.LongTensor(tgt_mlm), \
                  torch.LongTensor(tgt_match), \
                  torch.LongTensor(seg)
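# ViLT packs text and image positions into one segment vector:
#
#   seg = [1] * text_len + [0] * pad_num + [2] * (patch_num + 1)
#
# e.g. 4 text tokens, 2 pads, and 3 patches plus the extra leading slot give
# [1, 1, 1, 1, 0, 0, 2, 2, 2, 2]; tgt_mlm is extended with zeros over the image
# positions so its length matches seg.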
class ClipDataloader(VisionDataloader):
    def __iter__(self):
        """
        instances: (src_text, src_image, seg_text)
            src_text: Tokens of the text sample
            src_image: Path of the image sample
            seg_text: Segment input of the text sample

        Returns:
            src_text: [batch_size x seq_length]
            src_image: [batch_size x channel_size x height x width]
            seg_text: [batch_size x seq_length]
            seg_image: [batch_size x (patch_num + 1)]
        """
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode

        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src_text = []
            src_image = []
            seg_text = []
            seg_image = []

            for ins in instances:
                src_text_single, pad_num = ins[0]
                for _ in range(pad_num):
                    src_text_single.append(self.vocab.get(PAD_TOKEN))
                src_text.append(src_text_single)
                seg_text.append([1] * ins[1][0] + [0] * pad_num)

                image = read_image(ins[2], ImageReadMode.RGB)
                image = image.cuda(self.gpu_id)
                src_image.append(self.transform(image))
                seg_image.append([1] * ((self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1))

            yield torch.LongTensor(src_text), \
                  torch.stack(src_image, 0), \
                  torch.LongTensor(seg_text), \
                  torch.LongTensor(seg_image)


class AudioDataloader(Dataloader):
    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        super(AudioDataloader, self).__init__(args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle, model_for_dataloader)
        self.dataset_folder = os.path.dirname(dataset_path)
        self.sampling_rate = args.sampling_rate
        self.normalize_means, self.normalize_vars, self.ceptral_normalize = True, True, True
        self.padding_value = 0.0
        self.audio_feature_size = args.audio_feature_size
        self.conv_layers_num = args.conv_layers_num
        self.max_audio_frames = args.max_audio_frames
        self.specaugment = None

        if "normalize_means" not in args.audio_preprocess:
            self.normalize_means = False
        if "normalize_vars" not in args.audio_preprocess:
            self.normalize_vars = False
        if "ceptral_normalize" not in args.audio_preprocess:
            self.ceptral_normalize = False
        if "specaugment" in args:
            self.specaugment = SpecAugment(args)


def utterance_cmvn(x, normalize_means=True, normalize_vars=True, gpu_id=None):
    # Cepstral mean and variance normalization over a single utterance.
    mean = x.mean(axis=0)
    square_sums = (x ** 2).sum(axis=0)

    if normalize_means:
        x = torch.sub(x, mean)
    if normalize_vars:
        var = square_sums / x.size(0) - mean ** 2
        if gpu_id is not None:
            std = torch.sqrt(torch.maximum(var, torch.full(var.size(), 1e-10).cuda(gpu_id)))
        else:
            std = torch.sqrt(torch.maximum(var, torch.full(var.size(), 1e-10)))
        x = torch.div(x, std)

    return x
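# utterance_cmvn normalizes each feature dimension over the frames of one
# utterance. A worked toy example (two frames, one mel bin):
#
#   x = [[1.0], [3.0]]
#   mean = 2.0; var = (1 + 9) / 2 - 2.0 ** 2 = 1.0; std = 1.0
#   output = [[-1.0], [1.0]]
#
# The 1e-10 floor under the square root keeps the division stable for
# near-constant feature dimensions.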
class S2tDataloader(AudioDataloader):
    def __iter__(self):
        import torchaudio
        import torchaudio.compliance.kaldi as ta_kaldi

        padding_vector = torch.FloatTensor(self.audio_feature_size * [self.padding_value]).unsqueeze(0).cuda(self.gpu_id)

        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            tgt_in = []
            tgt_out = []
            src_audio = []
            seg_audio = []
            tgt_seg = []

            for ins in instances:
                text_single, pad_num = ins[0]
                for _ in range(pad_num):
                    text_single.append(self.vocab.get(PAD_TOKEN))

                waveform, _ = torchaudio.load(ins[2])  # waveform, sample_rate
                waveform = waveform * (2 ** 15)  # Kaldi compliance: 16-bit signed integers
                waveform = waveform.cuda(self.gpu_id)
                feature = ta_kaldi.fbank(waveform, num_mel_bins=self.audio_feature_size, sample_frequency=self.sampling_rate)
                if self.ceptral_normalize:
                    feature = utterance_cmvn(feature, self.normalize_means, self.normalize_vars, self.gpu_id)

                difference = self.max_audio_frames - feature.size(0)
                if difference < 0:
                    # Skip utterances longer than max_audio_frames.
                    continue
                else:
                    src_audio.append(torch.cat([feature] + [padding_vector] * difference))

                src_pad_num = int(self.max_audio_frames / self.conv_layers_num / 2) - int(feature.size(0) / self.conv_layers_num / 2)
                seg_audio.append([1] * int(feature.size(0) / self.conv_layers_num / 2) + [0] * src_pad_num)

                tgt_out.append(text_single[1:])
                text_single[-pad_num - 1] = self.vocab.get(PAD_TOKEN)
                tgt_in.append(text_single[:-1])
                pad_num = max(pad_num - 1, 0)  # left shifted, pad_num >= 0
                tgt_seg.append([1] * (len(tgt_in[-1]) - pad_num) + [0] * pad_num)

            if len(src_audio) == 0:
                continue

            if self.specaugment:
                src_audio = self.specaugment(src_audio)

            yield torch.stack(src_audio, 0), \
                  torch.LongTensor(tgt_out), \
                  torch.LongTensor(seg_audio), \
                  torch.LongTensor(tgt_in), \
                  torch.LongTensor(tgt_seg)


class BeitDataloader(VisionDataloader):
    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        super(BeitDataloader, self).__init__(args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle, model_for_dataloader)
        from tencentpretrain.utils.image_tokenizer import build_vqgan_model
        self.vqgan = self.model_for_dataloader

    def mask(self, image_tokens, mask_rate=0.15):
        mask_num = int(len(image_tokens) * mask_rate)
        mask_index = random.sample(range(1, len(image_tokens)), mask_num)
        tgt = [0] * len(image_tokens)
        for idx in mask_index:
            tgt[idx] = image_tokens[idx]
        return tgt, mask_index

    def __iter__(self):
        """
        instances: image_path
            image_path: Path of the image sample

        Returns:
            src_image: [batch_size x channel_size x height x width]
            tgt: [batch_size x (patch_num + 1)]
            seg: [batch_size x (patch_num + 1)]
            mask: list of masked patch indices per sample
        """
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode
        from tencentpretrain.utils.image_tokenizer import image_tokenize

        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt = []
            seg = []
            mask = []

            for ins in instances:
                image = read_image(ins, ImageReadMode.RGB)
                image = image.cuda(self.gpu_id)
                image = self.transform(image)
                src.append(image)
                image_tokens = [0] + image_tokenize(self.vqgan, image)
                tgt_single, mask_index = self.mask(image_tokens)
                tgt.append(tgt_single)
                mask.append(mask_index)
                seg.append([1] * ((self.image_height // self.patch_size) * (self.image_width // self.patch_size) + 1))

            yield torch.stack(src, 0), \
                  torch.LongTensor(tgt), \
                  torch.LongTensor(seg), \
                  mask
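# BeitDataloader.mask keeps the original token id only at the sampled
# positions, which become the prediction targets. A toy sketch with
# hypothetical ids, assuming a single position is drawn:
#
#   image_tokens = [0, 17, 42, 8]
#   mask_index   = [2]
#   tgt          = [0, 0, 42, 0]
#
# Position 0 is never sampled because it is the prepended placeholder slot.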
class DalleDataloader(VisionDataloader):
    def __init__(self, args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle=False, model_for_dataloader=None):
        super(DalleDataloader, self).__init__(args, dataset_path, batch_size, rank, world_size, gpu_id, shuffle, model_for_dataloader)
        from tencentpretrain.utils.image_tokenizer import build_vqgan_model
        self.vqgan = self.model_for_dataloader
        self.vocab_bias = args.tokenizer.vocab_bias

    def __iter__(self):
        from torchvision.io import read_image
        from torchvision.io.image import ImageReadMode
        from tencentpretrain.utils.image_tokenizer import image_tokenize

        while True:
            while self._empty():
                self._fill_buf()

            if self.start + self.batch_size >= self.end:
                instances = self.buffer[self.start:]
            else:
                instances = self.buffer[self.start: self.start + self.batch_size]

            self.start += self.batch_size

            src = []
            tgt = []
            seg = []

            for ins in instances:
                src_single, pad_num = ins[0]
                image = read_image(ins[2], ImageReadMode.RGB)
                image = image.cuda(self.gpu_id)
                image = self.transform(image)
                # Offset image token ids so they do not collide with text vocab ids.
                image_tokens = [i + self.vocab_bias for i in image_tokenize(self.vqgan, image)]
                src_single.extend(image_tokens)
                for _ in range(pad_num):
                    src_single.append(self.vocab.get(PAD_TOKEN))

                seg_single = [1] * ins[1][0] + [2] * len(image_tokens) + [0] * pad_num
                src.append(src_single)
                tgt.append(src_single[1:] + [self.vocab.get(SEP_TOKEN)])
                seg.append(seg_single)

            yield torch.LongTensor(src), \
                  torch.LongTensor(tgt), \
                  torch.LongTensor(seg)
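# A minimal usage sketch (hypothetical paths; `args` must carry every attribute
# read in Dataloader.__init__, e.g. tokenizer, vocab, instances_buffer_size):
#
#   loader = MlmDataloader(args, "dataset.pt", batch_size=32,
#                          rank=0, world_size=1, gpu_id=0, shuffle=True)
#   for src, tgt, seg in loader:
#       ...  # feed one batch to the model; iteration restarts the file forever
#
# Instances are sharded round-robin across workers by the
# (read_count - 1) % world_size == rank check in _fill_buf.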