File size: 5,055 Bytes
910e2ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
import os
import torch
import sys
sys.path.append(os.path.abspath('.'))
import argparse
import datetime
import random
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset
from collections import OrderedDict
from einops import rearrange

import json
import jsonlines
from tqdm import tqdm
from torch.utils.data import DataLoader, DistributedSampler

from trainer_misc import init_distributed_mode
from pyramid_dit import (
    SD3TextEncoderWithMask,
    FluxTextEncoderWithMask,
)


def get_args():
    """Parse the command-line options for text-feature extraction.

    Returns:
        argparse.Namespace with ``batch_size``, ``anno_file``,
        ``model_dtype``, ``model_name`` and ``model_path``. The distributed
        fields used later (``rank``, ``world_size``) are not defined here;
        presumably ``init_distributed_mode`` attaches them — confirm.
    """
    parser = argparse.ArgumentParser('Pytorch Multi-process script', add_help=False)
    parser.add_argument('--batch_size', default=4, type=int)
    parser.add_argument('--anno_file', type=str, default='', help="The video annotation file")
    # Any value other than bf16 / fp16 is treated as fp32 by the model builder.
    parser.add_argument('--model_dtype', default='bf16', type=str, help="The Model Dtype: bf16 or fp16 (anything else runs in fp32)")
    parser.add_argument('--model_name', default='pyramid_flux', type=str, help="The Model Architecture Name", choices=["pyramid_flux", "pyramid_mmdit"])
    parser.add_argument('--model_path', default='', type=str, help='The pre-trained weight path')
    return parser.parse_args()


class VideoTextDataset(Dataset):
    """Map-style dataset of ``(caption, feature_output_path)`` pairs.

    Backed by a JSON-lines annotation file whose records are expected to be
    dicts providing ``text`` (the caption to encode) and ``text_fea`` (the
    path the encoded features should be saved to).
    """

    def __init__(self, anno_file):
        super().__init__()

        # Load every annotation record eagerly (tqdm only shows progress).
        with jsonlines.open(anno_file, 'r') as reader:
            self.annotation = list(tqdm(reader))

    def __getitem__(self, index):
        """Return ``(text, text_fea_path)`` for the record at ``index``.

        Also makes sure the parent directory of ``text_fea_path`` exists so
        the feature file can be written later. Any failure (missing key,
        directory creation error, ...) is printed and reported as
        ``(None, None)`` so the caller can skip the sample.
        """
        try:
            anno = self.annotation[index]
            text = anno['text']
            text_fea_path = anno['text_fea']    # The text feature saving path
            text_fea_save_dir = os.path.dirname(text_fea_path)
            # A bare filename has no directory part; os.makedirs('') would
            # raise and the valid sample would be silently dropped.
            if text_fea_save_dir:
                os.makedirs(text_fea_save_dir, exist_ok=True)
            return text, text_fea_path
        except Exception as e:
            # Deliberate best-effort: report and let the collate step skip it.
            print(f'Error with {e}')
            return None, None

    def __len__(self):
        return len(self.annotation)


def build_data_loader(args):
    """Create this rank's DataLoader over ``args.anno_file``.

    Samples are sharded across ranks in file order (no shuffling) using
    ``args.world_size`` / ``args.rank``. Each batch is a dict with
    ``'text'`` (captions) and ``'output'`` (matching feature save paths).

    NOTE(review): DistributedSampler is left at its default drop_last=False,
    which per the PyTorch docs pads the index list by repeating samples so
    every rank gets the same count — a few captions may be encoded (and
    written) by two ranks.
    """

    def collate_fn(batch):
        # Samples the dataset could not read come back as (None, None);
        # keep only the usable ones, preserving order.
        usable = [(text, path) for text, path in batch if text is not None]
        return {
            'text': [text for text, _ in usable],
            'output': [path for _, path in usable],
        }

    dataset = VideoTextDataset(args.anno_file)
    sampler = DistributedSampler(
        dataset,
        num_replicas=args.world_size,
        rank=args.rank,
        shuffle=False,
    )
    return DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        shuffle=False,
        drop_last=False,
        collate_fn=collate_fn,
        num_workers=4,
        pin_memory=True,
    )


def build_model(args):
    """Instantiate the text encoder selected by ``args.model_name``.

    ``args.model_dtype`` chooses the weight dtype: 'bf16' -> bfloat16,
    'fp16' -> float16, anything else -> float32. 'pyramid_flux' builds a
    FluxTextEncoderWithMask and 'pyramid_mmdit' an SD3TextEncoderWithMask,
    both loaded from ``args.model_path``.

    Raises:
        NotImplementedError: for any other ``model_name``.
    """
    dtype_name, arch, weights_path = args.model_dtype, args.model_name, args.model_path

    reduced_precision = {'bf16': torch.bfloat16, 'fp16': torch.float16}
    torch_dtype = reduced_precision.get(dtype_name, torch.float32)

    if arch == "pyramid_flux":
        encoder_cls = FluxTextEncoderWithMask
    elif arch == "pyramid_mmdit":
        encoder_cls = SD3TextEncoderWithMask
    else:
        raise NotImplementedError

    return encoder_cls(weights_path, torch_dtype=torch_dtype)


def _barrier_if_distributed():
    """Synchronize all ranks, but only when a default process group exists.

    ``torch.distributed.barrier()`` raises when no process group has been
    initialized; guarding it lets the script also run as a plain single
    process.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        torch.distributed.barrier()


def main():
    """Encode every annotation's caption on this rank and save the features.

    Each rank processes its shard of ``--anno_file`` and writes one
    ``torch.save`` file per sample to that sample's ``text_fea`` path. The
    file holds a dict with ``prompt_embed``, ``prompt_attention_mask`` and
    ``pooled_prompt_embed``, each given a leading dimension of size 1 and
    moved to CPU.
    """
    args = get_args()

    init_distributed_mode(args)

    # fix the seed for reproducibility
    seed = 42
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)

    # NOTE(review): no device index is given; this assumes init_distributed_mode
    # has already selected this rank's GPU (torch.cuda.set_device) — confirm.
    device = torch.device('cuda')

    model = build_model(args)
    model.to(device)

    reduced_precision = {'bf16': torch.bfloat16, 'fp16': torch.float16}
    torch_dtype = reduced_precision.get(args.model_dtype, torch.float32)
    # CUDA autocast only supports reduced-precision dtypes (it warns and
    # disables itself for float32), so run fp32 without autocast explicitly.
    use_autocast = torch_dtype != torch.float32

    data_loader = build_data_loader(args)
    _barrier_if_distributed()

    for sample in tqdm(data_loader):
        texts = sample['text']
        outputs = sample['output']
        if not texts:
            # Every sample in this batch failed to load; nothing to encode.
            continue

        with torch.no_grad(), torch.autocast(device_type='cuda', dtype=torch_dtype, enabled=use_autocast):
            prompt_embeds, prompt_attention_masks, pooled_prompt_embeds = model(texts, device)

        for output_path, prompt_embed, prompt_attention_mask, pooled_prompt_embed in zip(
            outputs, prompt_embeds, prompt_attention_masks, pooled_prompt_embeds
        ):
            # Re-add the batch dim and copy to CPU so each file stores only
            # this sample's data, not the whole batch's storage.
            output_dict = {
                'prompt_embed': prompt_embed.unsqueeze(0).cpu().clone(),
                'prompt_attention_mask': prompt_attention_mask.unsqueeze(0).cpu().clone(),
                'pooled_prompt_embed': pooled_prompt_embed.unsqueeze(0).cpu().clone(),
            }
            torch.save(output_dict, output_path)

    _barrier_if_distributed()


if __name__ == "__main__":
    # Run the extraction only when executed as a script, never on import.
    main()