"""Extract text-encoder features for video captions with multiple GPU workers
and cache them to disk, one file per prompt."""

import os
import sys
sys.path.append(os.path.abspath('.'))

import argparse
import random

import numpy as np
import torch
import jsonlines
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, DistributedSampler

from trainer_misc import init_distributed_mode
from pyramid_dit import (
    SD3TextEncoderWithMask,
    FluxTextEncoderWithMask,
)


def get_args():
    parser = argparse.ArgumentParser('PyTorch multi-process text feature extraction script', add_help=False)
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--anno_file', type=str, default='', help="The video annotation file")
    parser.add_argument('--model_dtype', default='bf16', type=str, help="The model dtype: bf16 or fp16 (anything else falls back to fp32)")
    parser.add_argument('--model_name', default='pyramid_flux', type=str, help="The model architecture name", choices=["pyramid_flux", "pyramid_mmdit"])
    parser.add_argument('--model_path', default='', type=str, help='The pre-trained weight path')
    return parser.parse_args()


class VideoTextDataset(Dataset):
    """Reads a jsonlines annotation file where each record is a dict with
    keys `text` (the prompt) and `text_fea` (the path to save its feature)."""

    def __init__(self, anno_file):
        super().__init__()
        self.annotation = []
        with jsonlines.open(anno_file, 'r') as reader:
            for item in tqdm(reader):
                self.annotation.append(item)

    def __getitem__(self, index):
        try:
            anno = self.annotation[index]
            text = anno['text']
            text_fea_path = anno['text_fea']   # The text feature saving path
            # Make sure the directory for the saved feature exists
            os.makedirs(os.path.split(text_fea_path)[0], exist_ok=True)
            return text, text_fea_path
        except Exception as e:
            print(f'Error with {e}')
            return None, None

    def __len__(self):
        return len(self.annotation)
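
# A minimal sketch (not part of the original pipeline) for writing a toy
# annotation file to smoke-test VideoTextDataset. The prompts and the output
# directory are hypothetical placeholders; only the record schema
# ('text', 'text_fea') is taken from the dataset above.
def write_dummy_annotations(anno_file, feature_dir, prompts):
    """Write one {'text': ..., 'text_fea': ...} record per prompt."""
    with jsonlines.open(anno_file, mode='w') as writer:
        for idx, prompt in enumerate(prompts):
            writer.write({
                'text': prompt,
                'text_fea': os.path.join(feature_dir, f'{idx:06d}.pt'),
            })
# e.g. write_dummy_annotations('toy.jsonl', './text_features', ['a cat playing piano'])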

def build_data_loader(args):
    def collate_fn(batch):
        # Drop samples that failed to load (returned as (None, None))
        text_list = []
        output_path_list = []
        for text, text_fea_path in batch:
            if text is not None:
                text_list.append(text)
                output_path_list.append(text_fea_path)
        return {'text': text_list, 'output': output_path_list}

    dataset = VideoTextDataset(args.anno_file)
    # Each rank processes a disjoint, non-shuffled shard of the annotations
    sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        num_workers=4,
        pin_memory=True,
        sampler=sampler,
        collate_fn=collate_fn,
        drop_last=False,
    )
    return loader


def build_model(args):
    if args.model_dtype == 'bf16':
        torch_dtype = torch.bfloat16
    elif args.model_dtype == 'fp16':
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    if args.model_name == "pyramid_flux":
        text_encoder = FluxTextEncoderWithMask(args.model_path, torch_dtype=torch_dtype)
    elif args.model_name == "pyramid_mmdit":
        text_encoder = SD3TextEncoderWithMask(args.model_path, torch_dtype=torch_dtype)
    else:
        raise NotImplementedError(f"Unsupported model name: {args.model_name}")

    return text_encoder


def main():
    args = get_args()
    init_distributed_mode(args)

    # fix the seed for reproducibility
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    device = torch.device('cuda')

    model = build_model(args)
    model.to(device)

    if args.model_dtype == "bf16":
        torch_dtype = torch.bfloat16
    elif args.model_dtype == "fp16":
        torch_dtype = torch.float16
    else:
        torch_dtype = torch.float32

    data_loader = build_data_loader(args)
    torch.distributed.barrier()

    for sample in tqdm(data_loader):
        texts = sample['text']
        outputs = sample['output']

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
            prompt_embeds, prompt_attention_masks, pooled_prompt_embeds = model(texts, device)

        # Save each prompt's features to its own file, restoring the batch dim
        for output_path, prompt_embed, prompt_attention_mask, pooled_prompt_embed in zip(
            outputs, prompt_embeds, prompt_attention_masks, pooled_prompt_embeds
        ):
            output_dict = {
                'prompt_embed': prompt_embed.unsqueeze(0).cpu().clone(),
                'prompt_attention_mask': prompt_attention_mask.unsqueeze(0).cpu().clone(),
                'pooled_prompt_embed': pooled_prompt_embed.unsqueeze(0).cpu().clone(),
            }
            torch.save(output_dict, output_path)

    torch.distributed.barrier()


if __name__ == '__main__':
    main()
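
# Example launch (paths and script name are hypothetical; assumes
# init_distributed_mode picks up the standard torchrun environment variables
# RANK / WORLD_SIZE / LOCAL_RANK and sets args.rank / args.world_size):
#
#   torchrun --nproc_per_node=8 extract_text_features.py \
#       --batch_size 4 \
#       --model_name pyramid_flux \
#       --model_dtype bf16 \
#       --model_path /path/to/pyramid-flow-checkpoint \
#       --anno_file /path/to/annotations.jsonl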