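"""Distributed text-feature extraction for Pyramid Flow.

Reads (text, output-path) pairs from a jsonlines annotation file, encodes every
prompt with the text encoder matching the chosen backbone (Flux or SD3/MM-DiT),
and saves the resulting embeddings to disk so that downstream training can load
them directly instead of running the text encoders.

`trainer_misc` and `pyramid_dit` are modules from the surrounding repository;
the script is expected to be launched from the repository root.
"""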
import os
import sys
sys.path.append(os.path.abspath('.'))

import argparse
import random

import numpy as np
import torch
import jsonlines
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader, DistributedSampler

from trainer_misc import init_distributed_mode
from pyramid_dit import (
    SD3TextEncoderWithMask,
    FluxTextEncoderWithMask,
)


def get_args():
    parser = argparse.ArgumentParser('PyTorch multi-process text feature extraction script', add_help=False)
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--anno_file', type=str, default='', help="The video annotation file (jsonlines)")
    parser.add_argument('--model_dtype', default='bf16', type=str, help="The model dtype: bf16, fp16, or fp32")
    parser.add_argument('--model_name', default='pyramid_flux', type=str, help="The model architecture name", choices=["pyramid_flux", "pyramid_mmdit"])
    parser.add_argument('--model_path', default='', type=str, help='The pre-trained weight path')
    return parser.parse_args()
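

# Example launch (hypothetical script name and paths; init_distributed_mode is
# expected to read the rank/world-size environment variables set by torchrun):
#
#   torchrun --nproc_per_node=8 extract_text_features.py \
#       --anno_file annotations/videos.jsonl \
#       --model_name pyramid_flux \
#       --model_path /path/to/pyramid-flow-checkpoint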


class VideoTextDataset(Dataset):
    """Yields (text, feature-output-path) pairs from a jsonlines annotation file."""

    def __init__(self, anno_file):
        super().__init__()
        self.annotation = []
        with jsonlines.open(anno_file, 'r') as reader:
            for item in tqdm(reader):
                # Each item is a dict with keys: text, text_fea
                self.annotation.append(item)

    def __getitem__(self, index):
        try:
            anno = self.annotation[index]
            text = anno['text']
            text_fea_path = anno['text_fea']  # Where the extracted text feature will be saved
            text_fea_save_dir = os.path.split(text_fea_path)[0]
            os.makedirs(text_fea_save_dir, exist_ok=True)
            return text, text_fea_path
        except Exception as e:
            print(f'Error loading annotation {index}: {e}')
            return None, None

    def __len__(self):
        return len(self.annotation)
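
# Each line of --anno_file is a standalone JSON object. A minimal example line
# (hypothetical paths; only the keys below are read by this script):
#
#   {"text": "a dog running on the beach", "text_fea": "text_features/00001.pt"}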


def build_data_loader(args):
    def collate_fn(batch):
        # Drop samples that failed to load in __getitem__ (returned as (None, None))
        text_list = []
        output_path_list = []
        for text, text_fea_path in batch:
            if text is not None:
                text_list.append(text)
                output_path_list.append(text_fea_path)
        return {'text': text_list, 'output': output_path_list}

    dataset = VideoTextDataset(args.anno_file)
    # Shard the dataset across ranks without shuffling, so each rank encodes its own slice
    sampler = DistributedSampler(dataset, num_replicas=args.world_size, rank=args.rank, shuffle=False)
    loader = DataLoader(
        dataset, batch_size=args.batch_size, num_workers=4, pin_memory=True,
        sampler=sampler, collate_fn=collate_fn, drop_last=False,
    )
    return loader
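
# Note: with drop_last=False, DistributedSampler pads the index list by repeating
# samples so that every rank receives the same number of batches; a duplicated
# prompt just overwrites the same output file, so the padding is harmless here.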


def get_torch_dtype(model_dtype):
    # Map the CLI dtype string to the corresponding torch dtype (fp32 fallback)
    if model_dtype == 'bf16':
        return torch.bfloat16
    if model_dtype == 'fp16':
        return torch.float16
    return torch.float32


def build_model(args):
    torch_dtype = get_torch_dtype(args.model_dtype)

    # Pick the text encoder matching the chosen diffusion backbone
    if args.model_name == "pyramid_flux":
        text_encoder = FluxTextEncoderWithMask(args.model_path, torch_dtype=torch_dtype)
    elif args.model_name == "pyramid_mmdit":
        text_encoder = SD3TextEncoderWithMask(args.model_path, torch_dtype=torch_dtype)
    else:
        raise NotImplementedError(f"Unsupported model_name: {args.model_name}")

    return text_encoder
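
# Both encoder wrappers come from the Pyramid Flow codebase (pyramid_dit). As used
# in main() below, calling them with (texts, device) is assumed to return the tuple
# (prompt_embeds, prompt_attention_mask, pooled_prompt_embeds).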


def main():
    args = get_args()
    init_distributed_mode(args)

    # Fix the seed for reproducibility
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    device = torch.device('cuda')
    model = build_model(args)
    model.to(device)
    model.eval()

    torch_dtype = get_torch_dtype(args.model_dtype)
    data_loader = build_data_loader(args)

    # Wait until every rank has built its model and data loader
    torch.distributed.barrier()

    for sample in tqdm(data_loader):
        texts = sample['text']
        outputs = sample['output']

        with torch.no_grad(), torch.cuda.amp.autocast(enabled=True, dtype=torch_dtype):
            prompt_embeds, prompt_attention_masks, pooled_prompt_embeds = model(texts, device)

        # Save one feature file per prompt; unsqueeze(0) restores the batch dimension
        for output_path, prompt_embed, prompt_attention_mask, pooled_prompt_embed in zip(
            outputs, prompt_embeds, prompt_attention_masks, pooled_prompt_embeds
        ):
            output_dict = {
                'prompt_embed': prompt_embed.unsqueeze(0).cpu().clone(),
                'prompt_attention_mask': prompt_attention_mask.unsqueeze(0).cpu().clone(),
                'pooled_prompt_embed': pooled_prompt_embed.unsqueeze(0).cpu().clone(),
            }
            torch.save(output_dict, output_path)

    # Wait for all ranks to finish writing before exiting
    torch.distributed.barrier()


if __name__ == '__main__':
    main()
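
# Loading a saved feature file later, e.g. in a training script (a minimal sketch;
# the exact sequence lengths and embedding dims depend on the chosen text encoder):
#
#   fea = torch.load('text_features/00001.pt', map_location='cpu')
#   prompt_embed = fea['prompt_embed']                    # [1, seq_len, dim]
#   prompt_attention_mask = fea['prompt_attention_mask']  # [1, seq_len]
#   pooled_prompt_embed = fea['pooled_prompt_embed']      # [1, pooled_dim]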