AmitIsraeli committed on
Commit
64bf706
1 Parent(s): 3aaab28

Add model and inference app

VARtext_v1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bbaa03cee25cb0abba7ac5d476f6b800b78dda29c6cb2773a11b584022585fcf
+ size 1963751390
app.py CHANGED
@@ -1,7 +1,239 @@
  import gradio as gr

- def greet(name):
-     return "Hello " + name + "!!"

- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ import torch
+ from models import VQVAE, build_vae_var
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from transformers import AutoTokenizer, SiglipTextModel
+ from peft import LoraConfig, get_peft_model
+ import random
+ from torchvision.transforms import ToPILImage
+ import numpy as np
+ from moviepy.editor import ImageSequenceClip
  import gradio as gr
+ import tempfile
+ import os
+
+ class SimpleAdapter(nn.Module):
+     def __init__(self, input_dim=512, hidden_dim=1024, out_dim=1024):
+         super(SimpleAdapter, self).__init__()
+         self.layer1 = nn.Linear(input_dim, hidden_dim)
+         self.norm0 = nn.LayerNorm(input_dim)
+         self.activation1 = nn.GELU()
+         self.layer2 = nn.Linear(hidden_dim, out_dim)
+         self.norm2 = nn.LayerNorm(out_dim)
+         self._initialize_weights()
+
+     def _initialize_weights(self):
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.xavier_uniform_(m.weight, gain=0.001)
+                 nn.init.zeros_(m.bias)
+             elif isinstance(m, nn.LayerNorm):
+                 nn.init.ones_(m.weight)
+                 nn.init.zeros_(m.bias)
+
+     def forward(self, x):
+         x = self.norm0(x)
+         x = self.layer1(x)
+         x = self.activation1(x)
+         x = self.layer2(x)
+         x = self.norm2(x)
+         return x
+
+ class InrenceTextVAR(nn.Module):
+     def __init__(self, pl_checkpoint=None, start_class_id=578, hugging_face_token=None, siglip_model='google/siglip-base-patch16-224', device="cpu", MODEL_DEPTH=16):
+         super(InrenceTextVAR, self).__init__()
+         self.device = device
+         self.class_id = start_class_id
+         # Define layers
+         patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
+         self.vae, self.var = build_vae_var(
+             V=4096, Cvae=32, ch=160, share_quant_resi=4,
+             device=device, patch_nums=patch_nums,
+             num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
+         )
+         self.text_processor = AutoTokenizer.from_pretrained(siglip_model, token=hugging_face_token)
+         self.siglip_text_encoder = SiglipTextModel.from_pretrained(siglip_model, token=hugging_face_token).to(device)
+         self.adapter = SimpleAdapter(
+             input_dim=self.siglip_text_encoder.config.hidden_size,
+             out_dim=self.var.C  # Ensure dimensional consistency
+         ).to(device)
+         self.apply_lora_to_var()
+         if pl_checkpoint is not None:
+             state_dict = torch.load(pl_checkpoint, map_location="cpu")['state_dict']
+             var_state_dict = {k[len('var.'):]: v for k, v in state_dict.items() if k.startswith('var.')}
+             vae_state_dict = {k[len('vae.'):]: v for k, v in state_dict.items() if k.startswith('vae.')}
+             adapter_state_dict = {k[len('adapter.'):]: v for k, v in state_dict.items() if k.startswith('adapter.')}
+             self.var.load_state_dict(var_state_dict)
+             self.vae.load_state_dict(vae_state_dict)
+             self.adapter.load_state_dict(adapter_state_dict)
+         del self.vae.encoder
+
+     def apply_lora_to_var(self):
+         """
+         Applies LoRA (Low-Rank Adaptation) to the VAR model.
+         """
+         def find_linear_module_names(model):
+             linear_module_names = []
+             for name, module in model.named_modules():
+                 if isinstance(module, nn.Linear):
+                     linear_module_names.append(name)
+             return linear_module_names
+
+         linear_module_names = find_linear_module_names(self.var)
+
+         lora_config = LoraConfig(
+             r=8,
+             lora_alpha=32,
+             target_modules=linear_module_names,
+             lora_dropout=0.05,
+             bias="none",
+         )
+
+         self.var = get_peft_model(self.var, lora_config)
+
+     @torch.no_grad()
+     def generate_image(self, text, beta=1, seed=None, more_smooth=False, top_k=0, top_p=0.9):
+         if seed is None:
+             seed = random.randint(0, 2**32 - 1)
+         inputs = self.text_processor([text], padding="max_length", return_tensors="pt").to(self.device)
+         outputs = self.siglip_text_encoder(**inputs)
+         pooled_output = outputs.pooler_output  # pooled (EOS token) states
+         pooled_output = F.normalize(pooled_output, p=2, dim=-1)  # Normalize delta condition
+         cond_delta = F.normalize(pooled_output, p=2, dim=-1).to(self.device)  # Use correct device
+         cond_delta = self.adapter(cond_delta)
+         cond_delta = F.normalize(cond_delta, p=2, dim=-1)  # Normalize delta condition
+         generated_images = self.var.autoregressive_infer_cfg(
+             B=1,
+             label_B=self.class_id,
+             delta_condition=cond_delta[:1],
+             beta=beta,
+             alpha=1,
+             top_k=top_k,
+             top_p=top_p,
+             more_smooth=more_smooth,
+             g_seed=seed
+         )
+         image = ToPILImage()(generated_images[0].cpu())
+         return image
+
+     @torch.no_grad()
+     def generate_video(self, text, start_beta, target_beta, fps, length, top_k=0, top_p=0.9, seed=None,
+                        more_smooth=False,
+                        output_filename='output_video.mp4'):
+
+         if seed is None:
+             seed = random.randint(0, 2 ** 32 - 1)
+
+         num_frames = int(fps * length)
+         images = []
+
+         # Define an easing function for smoother interpolation
+         def ease_in_out(t):
+             return t * t * (3 - 2 * t)
+
+         # Generate t values between 0 and 1
+         t_values = np.linspace(0, 1, num_frames)
+         # Apply the easing function
+         eased_t_values = ease_in_out(t_values)
+         # Interpolate beta values using the eased t values
+         beta_values = start_beta + (target_beta - start_beta) * eased_t_values
+
+         for beta in beta_values:
+             image = self.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=top_k, top_p=top_p)
+             images.append(np.array(image))
+
+         # Create a video from images
+         clip = ImageSequenceClip(images, fps=fps)
+         clip.write_videofile(output_filename, codec='libx264')
+
+ if __name__ == '__main__':
+
+     # Initialize the model
+     checkpoint = 'VARtext_v1.pth'  # Replace with your actual checkpoint path
+     device = 'cpu' if not torch.cuda.is_available() else 'cuda'
+     state_dict = torch.load(checkpoint, map_location="cpu")
+     model = InrenceTextVAR(device=device)
+     model.load_state_dict(state_dict)
+     model.to(device)
+
+     def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
+         print(f"Generating image for text: {text}\n"
+               f"beta: {beta}\n"
+               f"seed: {seed}\n"
+               f"more_smooth: {more_smooth}\n"
+               f"top_k: {top_k}\n"
+               f"top_p: {top_p}\n")
+         image = model.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=int(top_k), top_p=top_p)
+         return image
+
+     def generate_video_gradio(text, start_beta=1.0, target_beta=1.0, fps=10, length=5.0, top_k=0, top_p=0.9, seed=None, more_smooth=False, progress=gr.Progress()):
+         print(f"Generating video for text: {text}\n"
+               f"start_beta: {start_beta}\n"
+               f"target_beta: {target_beta}\n"
+               f"seed: {seed}\n"
+               f"more_smooth: {more_smooth}\n"
+               f"top_k: {top_k}\n"
+               f"top_p: {top_p}\n"
+               f"fps: {fps}\n"
+               f"length: {length}\n")
+         with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmpfile:
+             output_filename = tmpfile.name
+         num_frames = int(fps * length)
+         beta_values = np.linspace(start_beta, target_beta, num_frames)
+         images = []
+
+         for i, beta in enumerate(beta_values):
+             image = model.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=top_k, top_p=top_p)
+             images.append(np.array(image))
+             # Update progress
+             progress((i + 1) / num_frames)
+             # Yield the frame image to update the GUI
+             yield image, gr.update()
+
+         # After generating all frames, create the video
+         clip = ImageSequenceClip(images, fps=fps)
+         clip.write_videofile(output_filename, codec='libx264')
+
+         # Yield the final video output
+         yield gr.update(), output_filename
+
+     with gr.Blocks() as demo:
+         gr.Markdown("# Text to Image/Video Generator")
+         with gr.Tab("Generate Image"):
+             text_input = gr.Textbox(label="Input Text")
+             beta_input = gr.Slider(label="Beta", minimum=0.0, maximum=2.5, step=0.05, value=1.0)
+             seed_input = gr.Number(label="Seed", value=None)
+             more_smooth_input = gr.Checkbox(label="More Smooth", value=False)
+             top_k_input = gr.Number(label="Top K", value=0)
+             top_p_input = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
+             generate_button = gr.Button("Generate Image")
+             image_output = gr.Image(label="Generated Image")
+             generate_button.click(
+                 generate_image_gradio,
+                 inputs=[text_input, beta_input, seed_input, more_smooth_input, top_k_input, top_p_input],
+                 outputs=image_output
+             )
+
+         with gr.Tab("Generate Video"):
+             text_input_video = gr.Textbox(label="Input Text")
+             start_beta_input = gr.Slider(label="Start Beta", minimum=0.0, maximum=2.5, step=0.05, value=0)
+             target_beta_input = gr.Slider(label="Target Beta", minimum=0.0, maximum=2.5, step=0.05, value=1.0)
+             fps_input = gr.Number(label="FPS", value=10)
+             length_input = gr.Number(label="Length (seconds)", value=5.0)
+             seed_input_video = gr.Number(label="Seed", value=None)
+             more_smooth_input_video = gr.Checkbox(label="More Smooth", value=False)
+             top_k_input_video = gr.Number(label="Top K", value=0)
+             top_p_input_video = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
+             generate_video_button = gr.Button("Generate Video")
+             frame_output = gr.Image(label="Current Frame")
+             video_output = gr.Video(label="Generated Video")
+
+             generate_video_button.click(
+                 generate_video_gradio,
+                 inputs=[text_input_video, start_beta_input, target_beta_input, fps_input, length_input, top_k_input_video, top_p_input_video, seed_input_video, more_smooth_input_video],
+                 outputs=[frame_output, video_output],
+                 queue=True  # Enable queuing to allow for progress updates
+             )
+
+     demo.launch()
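
For reference, the model added above can also be driven without the Gradio UI. A minimal sketch (not part of the commit), assuming the LFS checkpoint VARtext_v1.pth has been pulled into the working directory next to app.py:

    # Illustrative only: load the committed checkpoint and generate one image.
    import torch
    from app import InrenceTextVAR

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = InrenceTextVAR(device=device)
    model.load_state_dict(torch.load('VARtext_v1.pth', map_location='cpu'))
    model.to(device)

    image = model.generate_image("a red sports car", beta=1.0, seed=42)  # returns a PIL image
    image.save("sample.png")
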
dist.py ADDED
@@ -0,0 +1,211 @@
1
+ import datetime
2
+ import functools
3
+ import os
4
+ import sys
5
+ from typing import List
6
+ from typing import Union
7
+
8
+ import torch
9
+ import torch.distributed as tdist
10
+ import torch.multiprocessing as mp
11
+
12
+ __rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
13
+ __initialized = False
14
+
15
+
16
+ def initialized():
17
+ return __initialized
18
+
19
+
20
+ def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):
21
+ global __device
22
+ if not torch.cuda.is_available():
23
+ print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
24
+ return
25
+ elif 'RANK' not in os.environ:
26
+ torch.cuda.set_device(gpu_id_if_not_distibuted)
27
+ __device = torch.empty(1).cuda().device
28
+ print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
29
+ return
30
+ # then 'RANK' must exist
31
+ global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
32
+ local_rank = global_rank % num_gpus
33
+ torch.cuda.set_device(local_rank)
34
+
35
+ # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
36
+ if mp.get_start_method(allow_none=True) is None:
37
+ method = 'fork' if fork else 'spawn'
38
+ print(f'[dist initialize] mp method={method}')
39
+ mp.set_start_method(method)
40
+ tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))
41
+
42
+ global __rank, __local_rank, __world_size, __initialized
43
+ __local_rank = local_rank
44
+ __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
45
+ __device = torch.empty(1).cuda().device
46
+ __initialized = True
47
+
48
+ assert tdist.is_initialized(), 'torch.distributed is not initialized!'
49
+ print(f'[lrk={get_local_rank()}, rk={get_rank()}]')
50
+
51
+
52
+ def get_rank():
53
+ return __rank
54
+
55
+
56
+ def get_local_rank():
57
+ return __local_rank
58
+
59
+
60
+ def get_world_size():
61
+ return __world_size
62
+
63
+
64
+ def get_device():
65
+ return __device
66
+
67
+
68
+ def set_gpu_id(gpu_id: int):
69
+ if gpu_id is None: return
70
+ global __device
71
+ if isinstance(gpu_id, (str, int)):
72
+ torch.cuda.set_device(int(gpu_id))
73
+ __device = torch.empty(1).cuda().device
74
+ else:
75
+ raise NotImplementedError
76
+
77
+
78
+ def is_master():
79
+ return __rank == 0
80
+
81
+
82
+ def is_local_master():
83
+ return __local_rank == 0
84
+
85
+
86
+ def new_group(ranks: List[int]):
87
+ if __initialized:
88
+ return tdist.new_group(ranks=ranks)
89
+ return None
90
+
91
+
92
+ def barrier():
93
+ if __initialized:
94
+ tdist.barrier()
95
+
96
+
97
+ def allreduce(t: torch.Tensor, async_op=False):
98
+ if __initialized:
99
+ if not t.is_cuda:
100
+ cu = t.detach().cuda()
101
+ ret = tdist.all_reduce(cu, async_op=async_op)
102
+ t.copy_(cu.cpu())
103
+ else:
104
+ ret = tdist.all_reduce(t, async_op=async_op)
105
+ return ret
106
+ return None
107
+
108
+
109
+ def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
110
+ if __initialized:
111
+ if not t.is_cuda:
112
+ t = t.cuda()
113
+ ls = [torch.empty_like(t) for _ in range(__world_size)]
114
+ tdist.all_gather(ls, t)
115
+ else:
116
+ ls = [t]
117
+ if cat:
118
+ ls = torch.cat(ls, dim=0)
119
+ return ls
120
+
121
+
122
+ def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
123
+ if __initialized:
124
+ if not t.is_cuda:
125
+ t = t.cuda()
126
+
127
+ t_size = torch.tensor(t.size(), device=t.device)
128
+ ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
129
+ tdist.all_gather(ls_size, t_size)
130
+
131
+ max_B = max(size[0].item() for size in ls_size)
132
+ pad = max_B - t_size[0].item()
133
+ if pad:
134
+ pad_size = (pad, *t.size()[1:])
135
+ t = torch.cat((t, t.new_empty(pad_size)), dim=0)
136
+
137
+ ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
138
+ tdist.all_gather(ls_padded, t)
139
+ ls = []
140
+ for t, size in zip(ls_padded, ls_size):
141
+ ls.append(t[:size[0].item()])
142
+ else:
143
+ ls = [t]
144
+ if cat:
145
+ ls = torch.cat(ls, dim=0)
146
+ return ls
147
+
148
+
149
+ def broadcast(t: torch.Tensor, src_rank) -> None:
150
+ if __initialized:
151
+ if not t.is_cuda:
152
+ cu = t.detach().cuda()
153
+ tdist.broadcast(cu, src=src_rank)
154
+ t.copy_(cu.cpu())
155
+ else:
156
+ tdist.broadcast(t, src=src_rank)
157
+
158
+
159
+ def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
160
+ if not initialized():
161
+ return torch.tensor([val]) if fmt is None else [fmt % val]
162
+
163
+ ts = torch.zeros(__world_size)
164
+ ts[__rank] = val
165
+ allreduce(ts)
166
+ if fmt is None:
167
+ return ts
168
+ return [fmt % v for v in ts.cpu().numpy().tolist()]
169
+
170
+
171
+ def master_only(func):
172
+ @functools.wraps(func)
173
+ def wrapper(*args, **kwargs):
174
+ force = kwargs.pop('force', False)
175
+ if force or is_master():
176
+ ret = func(*args, **kwargs)
177
+ else:
178
+ ret = None
179
+ barrier()
180
+ return ret
181
+ return wrapper
182
+
183
+
184
+ def local_master_only(func):
185
+ @functools.wraps(func)
186
+ def wrapper(*args, **kwargs):
187
+ force = kwargs.pop('force', False)
188
+ if force or is_local_master():
189
+ ret = func(*args, **kwargs)
190
+ else:
191
+ ret = None
192
+ barrier()
193
+ return ret
194
+ return wrapper
195
+
196
+
197
+ def for_visualize(func):
198
+ @functools.wraps(func)
199
+ def wrapper(*args, **kwargs):
200
+ if is_master():
201
+ # with torch.no_grad():
202
+ ret = func(*args, **kwargs)
203
+ else:
204
+ ret = None
205
+ return ret
206
+ return wrapper
207
+
208
+
209
+ def finalize():
210
+ if __initialized:
211
+ tdist.destroy_process_group()
models/__init__.py ADDED
@@ -0,0 +1,39 @@
+ from typing import Tuple
+ import torch.nn as nn
+
+ from .quant import VectorQuantizer2
+ from .var import VAR
+ from .vqvae import VQVAE
+
+
+ def build_vae_var(
+     # Shared args
+     device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),   # 10 steps by default
+     # VQVAE args
+     V=4096, Cvae=32, ch=160, share_quant_resi=4,
+     # VAR args
+     num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
+     flash_if_available=True, fused_if_available=True,
+     init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,    # init_std < 0: automated
+ ) -> Tuple[VQVAE, VAR]:
+     heads = depth
+     width = depth * 64
+     dpr = 0.1 * depth / 24
+
+     # disable built-in initialization for speed
+     for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):
+         setattr(clz, 'reset_parameters', lambda self: None)
+
+     # build models
+     vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)
+     var_wo_ddp = VAR(
+         vae_local=vae_local,
+         num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
+         norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,
+         attn_l2_norm=attn_l2_norm,
+         patch_nums=patch_nums,
+         flash_if_available=flash_if_available, fused_if_available=fused_if_available,
+     ).to(device)
+     var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)
+
+     return vae_local, var_wo_ddp
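
build_vae_var derives the transformer shape from depth alone (heads = depth, width = depth * 64, drop-path rate = 0.1 * depth / 24), so the MODEL_DEPTH=16 default in app.py yields a 1024-dim, 16-head VAR. A sketch of the call as app.py makes it (illustrative; note that models/var.py as committed creates an MPS torch.Generator, so construction runs as-is only where MPS is available):

    import torch
    from models import build_vae_var

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,
        device=device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
        num_classes=1000, depth=16, shared_aln=False,
    )
    print(var.C, var.num_heads)  # 1024 16
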
models/basic_vae.py ADDED
@@ -0,0 +1,226 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ # this file only provides the 2 modules used in VQVAE
7
+ __all__ = ['Encoder', 'Decoder',]
8
+
9
+
10
+ """
11
+ References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
12
+ """
13
+ # swish
14
+ def nonlinearity(x):
15
+ return x * torch.sigmoid(x)
16
+
17
+
18
+ def Normalize(in_channels, num_groups=32):
19
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
20
+
21
+
22
+ class Upsample2x(nn.Module):
23
+ def __init__(self, in_channels):
24
+ super().__init__()
25
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
26
+
27
+ def forward(self, x):
28
+ return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))
29
+
30
+
31
+ class Downsample2x(nn.Module):
32
+ def __init__(self, in_channels):
33
+ super().__init__()
34
+ self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
35
+
36
+ def forward(self, x):
37
+ return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0))
38
+
39
+
40
+ class ResnetBlock(nn.Module):
41
+ def __init__(self, *, in_channels, out_channels=None, dropout): # conv_shortcut=False, # conv_shortcut: always False in VAE
42
+ super().__init__()
43
+ self.in_channels = in_channels
44
+ out_channels = in_channels if out_channels is None else out_channels
45
+ self.out_channels = out_channels
46
+
47
+ self.norm1 = Normalize(in_channels)
48
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
49
+ self.norm2 = Normalize(out_channels)
50
+ self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
51
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
52
+ if self.in_channels != self.out_channels:
53
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
54
+ else:
55
+ self.nin_shortcut = nn.Identity()
56
+
57
+ def forward(self, x):
58
+ h = self.conv1(F.silu(self.norm1(x), inplace=True))
59
+ h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
60
+ return self.nin_shortcut(x) + h
61
+
62
+
63
+ class AttnBlock(nn.Module):
64
+ def __init__(self, in_channels):
65
+ super().__init__()
66
+ self.C = in_channels
67
+
68
+ self.norm = Normalize(in_channels)
69
+ self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)
70
+ self.w_ratio = int(in_channels) ** (-0.5)
71
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
72
+
73
+ def forward(self, x):
74
+ qkv = self.qkv(self.norm(x))
75
+ B, _, H, W = qkv.shape # should be B,3C,H,W
76
+ C = self.C
77
+ q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)
78
+
79
+ # compute attention
80
+ q = q.view(B, C, H * W).contiguous()
81
+ q = q.permute(0, 2, 1).contiguous() # B,HW,C
82
+ k = k.view(B, C, H * W).contiguous() # B,C,HW
83
+ w = torch.bmm(q, k).mul_(self.w_ratio) # B,HW,HW w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
84
+ w = F.softmax(w, dim=2)
85
+
86
+ # attend to values
87
+ v = v.view(B, C, H * W).contiguous()
88
+ w = w.permute(0, 2, 1).contiguous() # B,HW,HW (first HW of k, second of q)
89
+ h = torch.bmm(v, w) # B, C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
90
+ h = h.view(B, C, H, W).contiguous()
91
+
92
+ return x + self.proj_out(h)
93
+
94
+
95
+ def make_attn(in_channels, using_sa=True):
96
+ return AttnBlock(in_channels) if using_sa else nn.Identity()
97
+
98
+
99
+ class Encoder(nn.Module):
100
+ def __init__(
101
+ self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
102
+ dropout=0.0, in_channels=3,
103
+ z_channels, double_z=False, using_sa=True, using_mid_sa=True,
104
+ ):
105
+ super().__init__()
106
+ self.ch = ch
107
+ self.num_resolutions = len(ch_mult)
108
+ self.downsample_ratio = 2 ** (self.num_resolutions - 1)
109
+ self.num_res_blocks = num_res_blocks
110
+ self.in_channels = in_channels
111
+
112
+ # downsampling
113
+ self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)
114
+
115
+ in_ch_mult = (1,) + tuple(ch_mult)
116
+ self.down = nn.ModuleList()
117
+ for i_level in range(self.num_resolutions):
118
+ block = nn.ModuleList()
119
+ attn = nn.ModuleList()
120
+ block_in = ch * in_ch_mult[i_level]
121
+ block_out = ch * ch_mult[i_level]
122
+ for i_block in range(self.num_res_blocks):
123
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
124
+ block_in = block_out
125
+ if i_level == self.num_resolutions - 1 and using_sa:
126
+ attn.append(make_attn(block_in, using_sa=True))
127
+ down = nn.Module()
128
+ down.block = block
129
+ down.attn = attn
130
+ if i_level != self.num_resolutions - 1:
131
+ down.downsample = Downsample2x(block_in)
132
+ self.down.append(down)
133
+
134
+ # middle
135
+ self.mid = nn.Module()
136
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
137
+ self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
138
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
139
+
140
+ # end
141
+ self.norm_out = Normalize(block_in)
142
+ self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)
143
+
144
+ def forward(self, x):
145
+ # downsampling
146
+ h = self.conv_in(x)
147
+ for i_level in range(self.num_resolutions):
148
+ for i_block in range(self.num_res_blocks):
149
+ h = self.down[i_level].block[i_block](h)
150
+ if len(self.down[i_level].attn) > 0:
151
+ h = self.down[i_level].attn[i_block](h)
152
+ if i_level != self.num_resolutions - 1:
153
+ h = self.down[i_level].downsample(h)
154
+
155
+ # middle
156
+ h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))
157
+
158
+ # end
159
+ h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
160
+ return h
161
+
162
+
163
+ class Decoder(nn.Module):
164
+ def __init__(
165
+ self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
166
+ dropout=0.0, in_channels=3, # in_channels: raw img channels
167
+ z_channels, using_sa=True, using_mid_sa=True,
168
+ ):
169
+ super().__init__()
170
+ self.ch = ch
171
+ self.num_resolutions = len(ch_mult)
172
+ self.num_res_blocks = num_res_blocks
173
+ self.in_channels = in_channels
174
+
175
+ # compute in_ch_mult, block_in and curr_res at lowest res
176
+ in_ch_mult = (1,) + tuple(ch_mult)
177
+ block_in = ch * ch_mult[self.num_resolutions - 1]
178
+
179
+ # z to block_in
180
+ self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)
181
+
182
+ # middle
183
+ self.mid = nn.Module()
184
+ self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
185
+ self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
186
+ self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
187
+
188
+ # upsampling
189
+ self.up = nn.ModuleList()
190
+ for i_level in reversed(range(self.num_resolutions)):
191
+ block = nn.ModuleList()
192
+ attn = nn.ModuleList()
193
+ block_out = ch * ch_mult[i_level]
194
+ for i_block in range(self.num_res_blocks + 1):
195
+ block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
196
+ block_in = block_out
197
+ if i_level == self.num_resolutions-1 and using_sa:
198
+ attn.append(make_attn(block_in, using_sa=True))
199
+ up = nn.Module()
200
+ up.block = block
201
+ up.attn = attn
202
+ if i_level != 0:
203
+ up.upsample = Upsample2x(block_in)
204
+ self.up.insert(0, up) # prepend to get consistent order
205
+
206
+ # end
207
+ self.norm_out = Normalize(block_in)
208
+ self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)
209
+
210
+ def forward(self, z):
211
+ # z to block_in
212
+ # middle
213
+ h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))
214
+
215
+ # upsampling
216
+ for i_level in reversed(range(self.num_resolutions)):
217
+ for i_block in range(self.num_res_blocks + 1):
218
+ h = self.up[i_level].block[i_block](h)
219
+ if len(self.up[i_level].attn) > 0:
220
+ h = self.up[i_level].attn[i_block](h)
221
+ if i_level != 0:
222
+ h = self.up[i_level].upsample(h)
223
+
224
+ # end
225
+ h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
226
+ return h
models/basic_var.py ADDED
@@ -0,0 +1,174 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from models.helpers import DropPath, drop_path
8
+
9
+
10
+ # this file only provides the 3 blocks used in VAR transformer
11
+ __all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']
12
+
13
+
14
+ # automatically import fused operators
15
+ dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
16
+ try:
17
+ from flash_attn.ops.layer_norm import dropout_add_layer_norm
18
+ from flash_attn.ops.fused_dense import fused_mlp_func
19
+ except ImportError: pass
20
+ # automatically import faster attention implementations
21
+ try: from xformers.ops import memory_efficient_attention
22
+ except ImportError: pass
23
+ try: from flash_attn import flash_attn_func # qkv: BLHc, ret: BLHcq
24
+ except ImportError: pass
25
+ try: from torch.nn.functional import scaled_dot_product_attention as slow_attn # q, k, v: BHLc
26
+ except ImportError:
27
+ def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
28
+ attn = query.mul(scale) @ key.transpose(-2, -1) # BHLc @ BHcL => BHLL
29
+ if attn_mask is not None: attn.add_(attn_mask)
30
+ return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value
31
+
32
+
33
+ class FFN(nn.Module):
34
+ def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
35
+ super().__init__()
36
+ self.fused_mlp_func = fused_mlp_func if fused_if_available else None
37
+ out_features = out_features or in_features
38
+ hidden_features = hidden_features or in_features
39
+ self.fc1 = nn.Linear(in_features, hidden_features)
40
+ self.act = nn.GELU(approximate='tanh')
41
+ self.fc2 = nn.Linear(hidden_features, out_features)
42
+ self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()
43
+
44
+ def forward(self, x):
45
+ if self.fused_mlp_func is not None:
46
+ return self.drop(self.fused_mlp_func(
47
+ x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
48
+ activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
49
+ heuristic=0, process_group=None,
50
+ ))
51
+ else:
52
+ return self.drop(self.fc2( self.act(self.fc1(x)) ))
53
+
54
+ def extra_repr(self) -> str:
55
+ return f'fused_mlp_func={self.fused_mlp_func is not None}'
56
+
57
+
58
+ class SelfAttention(nn.Module):
59
+ def __init__(
60
+ self, block_idx, embed_dim=768, num_heads=12,
61
+ attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
62
+ ):
63
+ super().__init__()
64
+ assert embed_dim % num_heads == 0
65
+ self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads # =64
66
+ self.attn_l2_norm = attn_l2_norm
67
+ if self.attn_l2_norm:
68
+ self.scale = 1
69
+ self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
70
+ self.max_scale_mul = torch.log(torch.tensor(100)).item()
71
+ else:
72
+ self.scale = 0.25 / math.sqrt(self.head_dim)
73
+
74
+ self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
75
+ self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
76
+ self.register_buffer('zero_k_bias', torch.zeros(embed_dim))
77
+
78
+ self.proj = nn.Linear(embed_dim, embed_dim)
79
+ self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
80
+ self.attn_drop: float = attn_drop
81
+ self.using_flash = flash_if_available and flash_attn_func is not None
82
+ self.using_xform = flash_if_available and memory_efficient_attention is not None
83
+
84
+ # only used during inference
85
+ self.caching, self.cached_k, self.cached_v = False, None, None
86
+
87
+ def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None
88
+
89
+ # NOTE: attn_bias is None during inference because kv cache is enabled
90
+ def forward(self, x, attn_bias):
91
+ B, L, C = x.shape
92
+
93
+ qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
94
+ main_type = qkv.dtype
95
+ # qkv: BL3Hc
96
+
97
+ using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
98
+ if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1 # q or k or v: BLHc
99
+ else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2 # q or k or v: BHLc
100
+
101
+ if self.attn_l2_norm:
102
+ scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
103
+ if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2) # 1H11 to 11H1
104
+ q = F.normalize(q, dim=-1).mul(scale_mul)
105
+ k = F.normalize(k, dim=-1)
106
+
107
+ if self.caching:
108
+ if self.cached_k is None: self.cached_k = k; self.cached_v = v
109
+ else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)
110
+
111
+ dropout_p = self.attn_drop if self.training else 0.0
112
+ if using_flash:
113
+ oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
114
+ elif self.using_xform:
115
+ oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
116
+ else:
117
+ oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)
118
+
119
+ return self.proj_drop(self.proj(oup))
120
+ # attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb()) # BHLc @ BHcL => BHLL
121
+ # attn = self.attn_drop(attn.softmax(dim=-1))
122
+ # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1) # BHLL @ BHLc = BHLc => BLHc => BLC
123
+
124
+ def extra_repr(self) -> str:
125
+ return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'
126
+
127
+
128
+ class AdaLNSelfAttn(nn.Module):
129
+ def __init__(
130
+ self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
131
+ num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
132
+ flash_if_available=False, fused_if_available=True,
133
+ ):
134
+ super(AdaLNSelfAttn, self).__init__()
135
+ self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
136
+ self.C, self.D = embed_dim, cond_dim
137
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
138
+ self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
139
+ self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)
140
+
141
+ self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
142
+ self.shared_aln = shared_aln
143
+ if self.shared_aln:
144
+ self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
145
+ else:
146
+ lin = nn.Linear(cond_dim, 6*embed_dim)
147
+ self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)
148
+
149
+ self.fused_add_norm_fn = None
150
+
151
+ # NOTE: attn_bias is None during inference because kv cache is enabled
152
+ def forward(self, x, cond_BD, attn_bias): # C: embed_dim, D: cond_dim
153
+ if self.shared_aln:
154
+ gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2) # 116C + B16C =unbind(2)=> 6 B1C
155
+ else:
156
+ gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
157
+ x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
158
+ x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2)) # this mul(gamma2) cannot be in-placed when FusedMLP is used
159
+ return x
160
+
161
+ def extra_repr(self) -> str:
162
+ return f'shared_aln={self.shared_aln}'
163
+
164
+
165
+ class AdaLNBeforeHead(nn.Module):
166
+ def __init__(self, C, D, norm_layer): # C: embed_dim, D: cond_dim
167
+ super().__init__()
168
+ self.C, self.D = C, D
169
+ self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
170
+ self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))
171
+
172
+ def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
173
+ scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
174
+ return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
models/helpers.py ADDED
@@ -0,0 +1,59 @@
+ import torch
+ from torch import nn as nn
+ from torch.nn import functional as F
+
+
+ def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:  # return idx, shaped (B, l)
+     B, l, V = logits_BlV.shape
+     if top_k > 0:
+         idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
+         logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
+     if top_p > 0:
+         sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
+         sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
+         sorted_idx_to_remove[..., -1:] = False
+         logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
+     # sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)
+     replacement = num_samples >= 0
+     num_samples = abs(num_samples)
+     return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)
+
+
+ def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:
+     if rng is None:
+         return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)
+
+     gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log())
+     gumbels = (logits + gumbels) / tau
+     y_soft = gumbels.softmax(dim)
+
+     if hard:
+         index = y_soft.max(dim, keepdim=True)[1]
+         y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
+         ret = y_hard - y_soft.detach() + y_soft
+     else:
+         ret = y_soft
+     return ret
+
+
+ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):  # taken from timm
+     if drop_prob == 0. or not training: return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0 and scale_by_keep:
+         random_tensor.div_(keep_prob)
+     return x * random_tensor
+
+
+ class DropPath(nn.Module):  # taken from timm
+     def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+         self.scale_by_keep = scale_by_keep
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+     def extra_repr(self):
+         return f'(drop_prob=...)'
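
sample_with_top_k_top_p_ above is the token sampler VAR inference relies on: it masks logits in place below the top-k threshold, applies nucleus (top-p) filtering, then draws indices with torch.multinomial. A small usage sketch with dummy shapes (not from the repo); note the in-place mutation, hence the clone():

    import torch
    from models.helpers import sample_with_top_k_top_p_

    logits = torch.randn(2, 5, 4096)   # (batch, positions, vocab)
    idx = sample_with_top_k_top_p_(logits.clone(), top_k=50, top_p=0.9, num_samples=1)
    print(idx.shape)                   # torch.Size([2, 5, 1])
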
models/quant.py ADDED
@@ -0,0 +1,281 @@
1
+ from typing import List, Optional, Sequence, Tuple, Union
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch import distributed as tdist, nn as nn
6
+ from torch.nn import functional as F
7
+
8
+ import dist
9
+
10
+ # this file only provides the VectorQuantizer2 used in VQVAE
11
+ __all__ = ['VectorQuantizer2', ]
12
+
13
+
14
+ class VectorQuantizer2(nn.Module):
15
+ # VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
16
+ def __init__(
17
+ self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
18
+ default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4, # share_quant_resi: args.qsr
19
+ ):
20
+ super().__init__()
21
+ self.vocab_size: int = vocab_size
22
+ self.Cvae: int = Cvae
23
+ self.using_znorm: bool = using_znorm
24
+ self.v_patch_nums: Tuple[int] = v_patch_nums
25
+
26
+ self.quant_resi_ratio = quant_resi
27
+ if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales
28
+ self.quant_resi = PhiNonShared(
29
+ [(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
30
+ range(default_qresi_counts or len(self.v_patch_nums))])
31
+ elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
32
+ self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
33
+ else: # partially shared: \phi_{1 to share_quant_resi} for K scales
34
+ self.quant_resi = PhiPartiallyShared(nn.ModuleList(
35
+ [(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
36
+ range(share_quant_resi)]))
37
+
38
+ self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
39
+ self.record_hit = 0
40
+
41
+ self.beta: float = beta
42
+ self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
43
+
44
+ # only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
45
+ self.prog_si = -1 # progressive training: not supported yet, prog_si always -1
46
+
47
+ def eini(self, eini):
48
+ if eini > 0:
49
+ nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
50
+ elif eini < 0:
51
+ self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)
52
+
53
+ def extra_repr(self) -> str:
54
+ return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'
55
+
56
+ # ===================== `forward` is only used in VAE training =====================
57
+ def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
58
+ dtype = f_BChw.dtype
59
+ if dtype != torch.float32: f_BChw = f_BChw.float()
60
+ B, C, H, W = f_BChw.shape
61
+ f_no_grad = f_BChw.detach()
62
+
63
+ f_rest = f_no_grad.clone()
64
+ f_hat = torch.zeros_like(f_rest)
65
+
66
+ with torch.cuda.amp.autocast(enabled=False):
67
+ mean_vq_loss: torch.Tensor = 0.0
68
+ vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)
69
+ SN = len(self.v_patch_nums)
70
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
71
+ # find the nearest embedding
72
+ if self.using_znorm:
73
+ rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1,
74
+ C) if (
75
+ si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
76
+ rest_NC = F.normalize(rest_NC, dim=-1)
77
+ idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
78
+ else:
79
+ rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1,
80
+ C) if (
81
+ si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
82
+ d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(
83
+ self.embedding.weight.data.square(), dim=1, keepdim=False)
84
+ d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
85
+ idx_N = torch.argmin(d_no_grad, dim=1)
86
+
87
+ hit_V = idx_N.bincount(minlength=self.vocab_size).float()
88
+ if self.training:
89
+ if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)
90
+
91
+ # calc loss
92
+ idx_Bhw = idx_N.view(B, pn, pn)
93
+ h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W),
94
+ mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(
95
+ idx_Bhw).permute(0, 3, 1, 2).contiguous()
96
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
97
+ f_hat = f_hat + h_BChw
98
+ f_rest -= h_BChw
99
+
100
+ if self.training and dist.initialized():
101
+ handler.wait()
102
+ if self.record_hit == 0:
103
+ self.ema_vocab_hit_SV[si].copy_(hit_V)
104
+ elif self.record_hit < 100:
105
+ self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
106
+ else:
107
+ self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
108
+ self.record_hit += 1
109
+ vocab_hit_V.add_(hit_V)
110
+ mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
111
+
112
+ mean_vq_loss *= 1. / SN
113
+ f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
114
+
115
+ margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
116
+ # margin = pn*pn / 100
117
+ if ret_usages:
118
+ usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in
119
+ enumerate(self.v_patch_nums)]
120
+ else:
121
+ usages = None
122
+ return f_hat, usages, mean_vq_loss
123
+
124
+ # ===================== `forward` is only used in VAE training =====================
125
+
126
+ def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[
127
+ List[torch.Tensor], torch.Tensor]:
128
+ ls_f_hat_BChw = []
129
+ B = ms_h_BChw[0].shape[0]
130
+ H = W = self.v_patch_nums[-1]
131
+ SN = len(self.v_patch_nums)
132
+ if all_to_max_scale:
133
+ f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
134
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
135
+ h_BChw = ms_h_BChw[si]
136
+ if si < len(self.v_patch_nums) - 1:
137
+ h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bilinear')
138
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
139
+ f_hat.add_(h_BChw)
140
+ if last_one:
141
+ ls_f_hat_BChw = f_hat
142
+ else:
143
+ ls_f_hat_BChw.append(f_hat.clone())
144
+ else:
145
+ # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
146
+ # WARNING: this should only be used for experimental purpose
147
+ f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0],
148
+ dtype=torch.float32)
149
+ for si, pn in enumerate(self.v_patch_nums): # from small to large
150
+ f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bilinear')
151
+ h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si])
152
+ f_hat.add_(h_BChw)
153
+ if last_one:
154
+ ls_f_hat_BChw = f_hat
155
+ else:
156
+ ls_f_hat_BChw.append(f_hat)
157
+
158
+ return ls_f_hat_BChw
159
+
160
+ def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool,
161
+ v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[
162
+ Union[torch.Tensor, torch.LongTensor]]: # z_BChw is the feature from inp_img_no_grad
163
+ B, C, H, W = f_BChw.shape
164
+ f_no_grad = f_BChw.detach()
165
+ f_rest = f_no_grad.clone()
166
+ f_hat = torch.zeros_like(f_rest)
167
+
168
+ f_hat_or_idx_Bl: List[torch.Tensor] = []
169
+
170
+ patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in
171
+ (v_patch_nums or self.v_patch_nums)] # from small to large
172
+ assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'
173
+
174
+ SN = len(patch_hws)
175
+ for si, (ph, pw) in enumerate(patch_hws): # from small to large
176
+ if 0 <= self.prog_si < si: break # progressive training: not supported yet, prog_si always -1
177
+ # find the nearest embedding
178
+ z_NC = F.interpolate(f_rest, size=(ph, pw), mode='bilinear').permute(0, 2, 3, 1).reshape(-1, C) if (
179
+ si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
180
+ if self.using_znorm:
181
+ z_NC = F.normalize(z_NC, dim=-1)
182
+ idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
183
+ else:
184
+ d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(
185
+ self.embedding.weight.data.square(), dim=1, keepdim=False)
186
+ d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
187
+ idx_N = torch.argmin(d_no_grad, dim=1)
188
+
189
+ idx_Bhw = idx_N.view(B, ph, pw)
190
+ h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W),
191
+ mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(idx_Bhw).permute(
192
+ 0, 3, 1, 2).contiguous()
193
+ h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
194
+ f_hat.add_(h_BChw)
195
+ f_rest.sub_(h_BChw)
196
+ f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw))
197
+
198
+ return f_hat_or_idx_Bl
199
+
200
+ # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
201
+ def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
202
+ next_scales = []
203
+ B = gt_ms_idx_Bl[0].shape[0]
204
+ C = self.Cvae
205
+ H = W = self.v_patch_nums[-1]
206
+ SN = len(self.v_patch_nums)
207
+
208
+ f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
209
+ pn_next: int = self.v_patch_nums[0]
210
+ for si in range(SN - 1):
211
+ if self.prog_si == 0 or (
212
+ 0 <= self.prog_si - 1 < si): break # progressive training: not supported yet, prog_si always -1
213
+ h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next),
214
+ size=(H, W), mode='bilinear')
215
+ f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw))
216
+ pn_next = self.v_patch_nums[si + 1]
217
+ next_scales.append(
218
+ F.interpolate(f_hat, size=(pn_next, pn_next), mode='bilinear').view(B, C, -1).transpose(1, 2))
219
+ return torch.cat(next_scales, dim=1) if len(next_scales) else None # cat BlCs to BLC, this should be float32
220
+
221
+ # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
222
+ def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[
223
+ Optional[torch.Tensor], torch.Tensor]: # only used in VAR inference
224
+ HW = self.v_patch_nums[-1]
225
+ if si != SN - 1:
226
+ h = self.quant_resi[si / (SN - 1)](
227
+ F.interpolate(h_BChw, size=(HW, HW), mode='bilinear')) # conv after upsample
228
+ f_hat.add_(h)
229
+ return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]),
230
+ mode='bilinear')
231
+ else:
232
+ h = self.quant_resi[si / (SN - 1)](h_BChw)
233
+ f_hat.add_(h)
234
+ return f_hat, f_hat
235
+
236
+
237
+ class Phi(nn.Conv2d):
238
+ def __init__(self, embed_dim, quant_resi):
239
+ ks = 3
240
+ super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
241
+ self.resi_ratio = abs(quant_resi)
242
+
243
+ def forward(self, h_BChw):
244
+ return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)
245
+
246
+
247
+ class PhiShared(nn.Module):
248
+ def __init__(self, qresi: Phi):
249
+ super().__init__()
250
+ self.qresi: Phi = qresi
251
+
252
+ def __getitem__(self, _) -> Phi:
253
+ return self.qresi
254
+
255
+
256
+ class PhiPartiallyShared(nn.Module):
257
+ def __init__(self, qresi_ls: nn.ModuleList):
258
+ super().__init__()
259
+ self.qresi_ls = qresi_ls
260
+ K = len(qresi_ls)
261
+ self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
262
+
263
+ def __getitem__(self, at_from_0_to_1: float) -> Phi:
264
+ return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]
265
+
266
+ def extra_repr(self) -> str:
267
+ return f'ticks={self.ticks}'
268
+
269
+
270
+ class PhiNonShared(nn.ModuleList):
271
+ def __init__(self, qresi: List):
272
+ super().__init__(qresi)
273
+ # self.qresi = qresi
274
+ K = len(qresi)
275
+ self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)
276
+
277
+ def __getitem__(self, at_from_0_to_1: float) -> Phi:
278
+ return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())
279
+
280
+ def extra_repr(self) -> str:
281
+ return f'ticks={self.ticks}'
models/var.py ADDED
@@ -0,0 +1,360 @@
1
+ import math
2
+ from functools import partial
3
+ from typing import Optional, Tuple, Union
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ from huggingface_hub import PyTorchModelHubMixin
8
+
9
+ import dist
10
+ from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
11
+ from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
12
+ from models.vqvae import VQVAE, VectorQuantizer2
13
+
14
+
15
+ class SharedAdaLin(nn.Linear):
16
+ def forward(self, cond_BD):
17
+ C = self.weight.shape[0] // 6
18
+ return super().forward(cond_BD).view(-1, 1, 6, C) # B16C
19
+
20
+
21
+ class VAR(nn.Module):
22
+ def __init__(
23
+ self, vae_local: VQVAE,
24
+ num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.,
25
+ drop_path_rate=0.,
26
+ norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
27
+ attn_l2_norm=False,
28
+ patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
29
+ flash_if_available=True, fused_if_available=True,
30
+ ):
31
+ super().__init__()
32
+ # 0. hyperparameters
33
+ assert embed_dim % num_heads == 0
34
+ self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
35
+ self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads
36
+
37
+ self.cond_drop_rate = cond_drop_rate
38
+ self.prog_si = -1 # progressive training
39
+
40
+ self.patch_nums: Tuple[int] = patch_nums
41
+ self.L = sum(pn ** 2 for pn in self.patch_nums)
42
+ self.first_l = self.patch_nums[0] ** 2
43
+ self.begin_ends = []
44
+ cur = 0
45
+ for i, pn in enumerate(self.patch_nums):
46
+ self.begin_ends.append((cur, cur + pn ** 2))
47
+ cur += pn ** 2
48
+
49
+ self.num_stages_minus_1 = len(self.patch_nums) - 1
50
+ self.rng = torch.Generator(device="mps")
51
+
52
+ # 1. input (word) embedding
53
+ quant: VectorQuantizer2 = vae_local.quantize
54
+ self.vae_proxy: Tuple[VQVAE] = (vae_local,)
55
+ self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
56
+ self.word_embed = nn.Linear(self.Cvae, self.C)
57
+
58
+ # 2. class embedding
59
+ init_std = math.sqrt(1 / self.C / 3)
60
+ self.num_classes = num_classes
61
+ self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32,
62
+ device=dist.get_device())
63
+ self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
64
+ nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
65
+ self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
66
+ nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)
67
+
68
+ # 3. absolute position embedding
69
+ pos_1LC = []
70
+ for i, pn in enumerate(self.patch_nums):
71
+ pe = torch.empty(1, pn * pn, self.C)
72
+ nn.init.trunc_normal_(pe, mean=0, std=init_std)
73
+ pos_1LC.append(pe)
74
+ pos_1LC = torch.cat(pos_1LC, dim=1) # 1, L, C
75
+ assert tuple(pos_1LC.shape) == (1, self.L, self.C)
76
+ self.pos_1LC = nn.Parameter(pos_1LC)
77
+ # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
78
+ self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
79
+ nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)
80
+
81
+ # 4. backbone blocks
82
+ self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False),
83
+ SharedAdaLin(self.D, 6 * self.C)) if shared_aln else nn.Identity()
84
+
85
+ norm_layer = partial(nn.LayerNorm, eps=norm_eps)
86
+ self.drop_path_rate = drop_path_rate
87
+ dpr = [x.item() for x in
88
+ torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule (linearly increasing)
89
+ self.blocks = nn.ModuleList([
90
+ AdaLNSelfAttn(
91
+ cond_dim=self.D, shared_aln=shared_aln,
92
+ block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
93
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx],
94
+ last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
95
+ attn_l2_norm=attn_l2_norm,
96
+ flash_if_available=flash_if_available, fused_if_available=fused_if_available,
97
+ )
98
+ for block_idx in range(depth)
99
+ ])
100
+
101
+ fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
102
+ self.using_fused_add_norm_fn = any(fused_add_norm_fns)
103
+ print(
104
+ f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
105
+ f' [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
106
+ f' [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
107
+ end='\n\n', flush=True
108
+ )
109
+
110
+ # 5. attention mask used in training (for masking out the future)
111
+ # it won't be used in inference, since kv cache is enabled
112
+ d: torch.Tensor = torch.cat([torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L,
113
+ 1)
114
+ dT = d.transpose(1, 2) # dT: 11L
115
+ lvl_1L = dT[:, 0].contiguous()
116
+ self.register_buffer('lvl_1L', lvl_1L)
117
+ attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
118
+ self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())
119
+
120
+ # 6. classifier head
121
+ self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
122
+ self.head = nn.Linear(self.C, self.V)
123
+
124
+ def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
125
+ cond_BD: Optional[torch.Tensor]):
126
+ if not isinstance(h_or_h_and_residual, torch.Tensor):
127
+ h, resi = h_or_h_and_residual # fused_add_norm must be used
128
+ h = resi + self.blocks[-1].drop_path(h)
129
+ else: # fused_add_norm is not used
130
+ h = h_or_h_and_residual
131
+ return self.head(self.head_nm(h.float(), cond_BD).float()).float()
132
+
133
+ @torch.no_grad()
134
+ def autoregressive_infer_cfg(
135
+ self, B: int, label_B: Optional[Union[int, torch.LongTensor]],
136
+ delta_condition: torch.Tensor, alpha: float, beta: float,
137
+ g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,
138
+ more_smooth=False,
139
+ ) -> torch.Tensor: # returns reconstructed image (B, 3, H, W) in [0, 1]
140
+ """
141
+ Generate images using autoregressive inference with classifier-free guidance.
142
+ :param B: batch size
143
+ :param label_B: class labels; if None, randomly sampled
144
+ :param delta_condition: tensor of shape (B, D)
145
+ :param alpha: scalar weight for class embedding
146
+ :param beta: scalar weight for delta_condition
147
+ :param g_seed: random seed
148
+ :param cfg: classifier-free guidance ratio
149
+ :param top_k: top-k sampling
150
+ :param top_p: top-p sampling
151
+ :param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
152
+ :return: reconstructed images (B, 3, H, W)
153
+ """
154
+ if g_seed is None:
155
+ rng = None
156
+ else:
157
+ self.rng.manual_seed(g_seed)
158
+ rng = self.rng
159
+
160
+ device = self.lvl_1L.device
161
+ if label_B is None:
162
+ label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
163
+ elif isinstance(label_B, int):
164
+ label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=device)
165
+
166
+ # Prepare labels for conditioned and unconditioned versions
167
+ label_B_cond = label_B
168
+ label_B_uncond = torch.full_like(label_B, fill_value=self.num_classes)
169
+ label_B = torch.cat((label_B_cond, label_B_uncond), dim=0) # shape (2B,)
170
+
171
+ # Prepare delta_condition for conditioned and unconditioned versions
172
+ delta_condition_uncond = torch.zeros_like(delta_condition)
173
+ delta_condition = torch.cat((delta_condition, delta_condition_uncond), dim=0) # shape (2B, D)
174
+
175
+ class_emb = self.class_emb(label_B) # shape (2B, D)
176
+ cond_BD = alpha * class_emb + beta * delta_condition # shape (2B, D)
177
+
178
+ sos = cond_BD.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1)
179
+
180
+ lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
181
+ next_token_map = sos + lvl_pos[:, :self.first_l]
182
+
183
+ cur_L = 0
184
+ f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])
185
+
186
+ for b in self.blocks:
187
+ b.attn.kv_caching(True)
188
+ for si, pn in enumerate(self.patch_nums): # si: i-th segment
189
+ ratio = si / self.num_stages_minus_1
190
+ cur_L += pn * pn
191
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD)
192
+ x = next_token_map
193
+
194
+ for b in self.blocks:
195
+ x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
196
+ logits_BlV = self.get_logits(x, cond_BD)
197
+
198
+ t = cfg * ratio
199
+ logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]
200
+
201
+ idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
202
+ if not more_smooth: # this is the default case
203
+ h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl) # B, l, Cvae
204
+ else: # not used when evaluating FID/IS/Precision/Recall
205
+ gum_t = max(0.27 * (1 - ratio * 0.95), 0.005) # refer to mask-git
206
+ h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ \
207
+ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)
208
+
209
+ h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
210
+ f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums),
211
+ f_hat, h_BChw)
212
+ if si != self.num_stages_minus_1: # prepare for next stage
213
+ next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
214
+ next_token_map = self.word_embed(next_token_map) + lvl_pos[:,
215
+ cur_L:cur_L + self.patch_nums[si + 1] ** 2]
216
+ next_token_map = next_token_map.repeat(2, 1, 1) # double the batch sizes due to CFG
217
+
218
+ for b in self.blocks:
219
+ b.attn.kv_caching(False)
220
+ return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5) # de-normalize, from [-1, 1] to [0, 1]
221
+
222
+ def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor, delta_condition: torch.Tensor,
223
+ alpha: float, beta: float) -> torch.Tensor:
224
+ """
225
+ :param label_B: label_B
226
+ :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
227
+ :param delta_condition: tensor of shape (B, D)
228
+ :param alpha: scalar weight for class embedding
229
+ :param beta: scalar weight for delta_condition
230
+ :return: logits BLV, V is vocab_size
231
+ """
232
+ bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
233
+ B = x_BLCv_wo_first_l.shape[0]
234
+ with torch.cuda.amp.autocast(enabled=False):
235
+ # Implement conditional dropout
236
+ drop_mask = torch.rand(B, device=label_B.device) < self.cond_drop_rate
237
+ label_B_dropped = torch.where(drop_mask, self.num_classes, label_B)
238
+ delta_condition_dropped = delta_condition.clone()
239
+ delta_condition_dropped[drop_mask] = 0.0 # Drop delta_condition
240
+
241
+ class_emb = self.class_emb(label_B_dropped)
242
+ cond_BD = alpha * class_emb + beta * delta_condition_dropped
243
+
244
+ sos = cond_BD.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)
245
+
246
+ if self.prog_si == 0:
247
+ x_BLC = sos
248
+ else:
249
+ x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)
250
+ x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed] # lvl: BLC; pos: 1LC
251
+
252
+ attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
253
+ cond_BD_or_gss = self.shared_ada_lin(cond_BD)
254
+
255
+ # hack: get the dtype if mixed precision is used
256
+ temp = x_BLC.new_ones(8, 8)
257
+ main_type = torch.matmul(temp, temp).dtype
258
+
259
+ x_BLC = x_BLC.to(dtype=main_type)
260
+ cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
261
+ attn_bias = attn_bias.to(dtype=main_type)
262
+
263
+ AdaLNSelfAttn.forward
264
+ for i, b in enumerate(self.blocks):
265
+ x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
266
+ x_BLC = self.get_logits(x_BLC.float(), cond_BD)
267
+
268
+ if self.prog_si == 0:
269
+ if isinstance(self.word_embed, nn.Linear):
270
+ x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
271
+ else:
272
+ s = 0
273
+ for p in self.word_embed.parameters():
274
+ if p.requires_grad:
275
+ s += p.view(-1)[0] * 0
276
+ x_BLC[0, 0, 0] += s
277
+ return x_BLC # logits BLV, V is vocab_size
278
+
279
+ def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
280
+ if init_std < 0: init_std = (1 / self.C / 3) ** 0.5 # init_std < 0: automated
281
+
282
+ print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
283
+ for m in self.modules():
284
+ with_weight = hasattr(m, 'weight') and m.weight is not None
285
+ with_bias = hasattr(m, 'bias') and m.bias is not None
286
+ if isinstance(m, nn.Linear):
287
+ nn.init.trunc_normal_(m.weight.data, std=init_std)
288
+ if with_bias: m.bias.data.zero_()
289
+ elif isinstance(m, nn.Embedding):
290
+ nn.init.trunc_normal_(m.weight.data, std=init_std)
291
+ if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
292
+ elif isinstance(m, (
293
+ nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm,
294
+ nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
295
+ if with_weight: m.weight.data.fill_(1.)
296
+ if with_bias: m.bias.data.zero_()
297
+ # conv: VAR has no conv, only VQVAE has conv
298
+ elif isinstance(m, (
299
+ nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
300
+ if conv_std_or_gain > 0:
301
+ nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
302
+ else:
303
+ nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
304
+ if with_bias: m.bias.data.zero_()
305
+
306
+ if init_head >= 0:
307
+ if isinstance(self.head, nn.Linear):
308
+ self.head.weight.data.mul_(init_head)
309
+ self.head.bias.data.zero_()
310
+ elif isinstance(self.head, nn.Sequential):
311
+ self.head[-1].weight.data.mul_(init_head)
312
+ self.head[-1].bias.data.zero_()
313
+
314
+ if isinstance(self.head_nm, AdaLNBeforeHead):
315
+ self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
316
+ if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
317
+ self.head_nm.ada_lin[-1].bias.data.zero_()
318
+
319
+ depth = len(self.blocks)
320
+ for block_idx, sab in enumerate(self.blocks):
321
+ sab: AdaLNSelfAttn
322
+ sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
323
+ sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
324
+ if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
325
+ nn.init.ones_(sab.ffn.fcg.bias)
326
+ nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
327
+ if hasattr(sab, 'ada_lin'):
328
+ sab.ada_lin[-1].weight.data[2 * self.C:].mul_(init_adaln)
329
+ sab.ada_lin[-1].weight.data[:2 * self.C].mul_(init_adaln_gamma)
330
+ if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
331
+ sab.ada_lin[-1].bias.data.zero_()
332
+ elif hasattr(sab, 'ada_gss'):
333
+ sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
334
+ sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)
335
+
336
+ def extra_repr(self):
337
+ return f'drop_path_rate={self.drop_path_rate:g}'
338
+
339
+
340
+ class VARHF(VAR, PyTorchModelHubMixin):
341
+ def __init__(
342
+ self,
343
+ vae_kwargs,
344
+ num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0.,
345
+ drop_path_rate=0.,
346
+ norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
347
+ attn_l2_norm=False,
348
+ patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # 10 steps by default
349
+ flash_if_available=True, fused_if_available=True,
350
+ ):
351
+ vae_local = VQVAE(**vae_kwargs)
352
+ super().__init__(
353
+ vae_local=vae_local,
354
+ num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
355
+ drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
356
+ norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
357
+ attn_l2_norm=attn_l2_norm,
358
+ patch_nums=patch_nums,
359
+ flash_if_available=flash_if_available, fused_if_available=fused_if_available,
360
+ )
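A minimal sketch of driving VAR.autoregressive_infer_cfg above for sampling. It assumes build_vae_var from this repo's models package is available and that trained weights would be loaded before calling it; the class id, alpha/beta, cfg, and top-k/top-p values are illustrative only.

import torch
from models import build_vae_var

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae, var = build_vae_var(
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    device=device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    num_classes=1000, depth=16, shared_aln=False,
)
var.eval()  # trained weights would be loaded here in practice

B = 2
delta = torch.zeros(B, var.C, device=device)  # stand-in for an adapted text embedding of shape (B, D)
with torch.no_grad():
    imgs = var.autoregressive_infer_cfg(
        B=B, label_B=578, delta_condition=delta,
        alpha=1.0, beta=1.0, cfg=4.0, top_k=900, top_p=0.95,
    )
print(imgs.shape)  # (B, 3, 256, 256), values in [0, 1]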
models/vqvae.py ADDED
@@ -0,0 +1,95 @@
1
+ """
2
+ References:
3
+ - VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
4
+ - GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
5
+ - VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
6
+ """
7
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+
12
+ from .basic_vae import Decoder, Encoder
13
+ from .quant import VectorQuantizer2
14
+
15
+
16
+ class VQVAE(nn.Module):
17
+ def __init__(
18
+ self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,
19
+ beta=0.25, # commitment loss weight
20
+ using_znorm=False, # whether to normalize when computing the nearest neighbors
21
+ quant_conv_ks=3, # quant conv kernel size
22
+ quant_resi=0.5, # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
23
+ share_quant_resi=4, # use 4 \phi layers for K scales: partially-shared \phi
24
+ default_qresi_counts=0, # if is 0: automatically set to len(v_patch_nums)
25
+ v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16), # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
26
+ test_mode=True,
27
+ ):
28
+ super().__init__()
29
+ self.test_mode = test_mode
30
+ self.V, self.Cvae = vocab_size, z_channels
31
+ # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
32
+ ddconfig = dict(
33
+ dropout=dropout, ch=ch, z_channels=z_channels,
34
+ in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2, # from vq-f16/config.yaml above
35
+ using_sa=True, using_mid_sa=True, # from vq-f16/config.yaml above
36
+ # resamp_with_conv=True, # always True, removed.
37
+ )
38
+ ddconfig.pop('double_z', None) # only KL-VAE should use double_z=True
39
+ self.encoder = Encoder(double_z=False, **ddconfig)
40
+ self.decoder = Decoder(**ddconfig)
41
+
42
+ self.vocab_size = vocab_size
43
+ self.downsample = 2 ** (len(ddconfig['ch_mult'])-1)
44
+ self.quantize: VectorQuantizer2 = VectorQuantizer2(
45
+ vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,
46
+ default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi,
47
+ )
48
+ self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
49
+ self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks//2)
50
+
51
+ if self.test_mode:
52
+ self.eval()
53
+ [p.requires_grad_(False) for p in self.parameters()]
54
+
55
+ # ===================== `forward` is only used in VAE training =====================
56
+ def forward(self, inp, ret_usages=False): # -> rec_B3HW, idx_N, loss
57
+ VectorQuantizer2.forward  # no-op attribute reference (has no runtime effect)
58
+ f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)
59
+ return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
60
+ # ===================== `forward` is only used in VAE training =====================
61
+
62
+ def fhat_to_img(self, f_hat: torch.Tensor):
63
+ return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)
64
+
65
+ def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]: # return List[Bl]
66
+ f = self.quant_conv(self.encoder(inp_img_no_grad))
67
+ return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)
68
+
69
+ def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
70
+ B = ms_idx_Bl[0].shape[0]
71
+ ms_h_BChw = []
72
+ for idx_Bl in ms_idx_Bl:
73
+ l = idx_Bl.shape[1]
74
+ pn = round(l ** 0.5)
75
+ ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))
76
+ return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)
77
+
78
+ def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
79
+ if last_one:
80
+ return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1)
81
+ else:
82
+ return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)]
83
+
84
+ def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:
85
+ f = self.quant_conv(self.encoder(x))
86
+ ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)
87
+ if last_one:
88
+ return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
89
+ else:
90
+ return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]
91
+
92
+ def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
93
+ if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:
94
+ state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV
95
+ return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
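A minimal sketch of the multi-scale round trip exposed above: encode an image to per-scale token indices with img_to_idxBl, then decode with idxBl_to_img. The weights here are randomly initialized, so this only illustrates shapes and the API, not reconstruction quality.

import torch
from models.vqvae import VQVAE

vqvae = VQVAE(vocab_size=4096, z_channels=32, ch=160, test_mode=True)
x = torch.rand(1, 3, 256, 256) * 2 - 1          # input in [-1, 1]; 256 = 16 (downsample) * 16 (largest patch number)
with torch.no_grad():
    ms_idx_Bl = vqvae.img_to_idxBl(x)           # list of (B, pn*pn) LongTensors, one per scale
    rec = vqvae.idxBl_to_img(ms_idx_Bl, same_shape=True, last_one=True)
print([tuple(t.shape) for t in ms_idx_Bl])      # [(1, 1), (1, 4), ..., (1, 256)]
print(rec.shape)                                # (1, 3, 256, 256), values clamped to [-1, 1]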
utils/amp_sc.py ADDED
@@ -0,0 +1,89 @@
1
+ import math
2
+ from typing import List, Optional, Tuple, Union
3
+
4
+ import torch
5
+
6
+
7
+ class NullCtx:
8
+ def __enter__(self):
9
+ pass
10
+
11
+ def __exit__(self, exc_type, exc_val, exc_tb):
12
+ pass
13
+
14
+
15
+ class AmpOptimizer:
16
+ def __init__(
17
+ self,
18
+ mixed_precision: int,
19
+ optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],
20
+ grad_clip: float, n_gradient_accumulation: int = 1,
21
+ ):
22
+ self.enable_amp = mixed_precision > 0
23
+ self.using_fp16_rather_bf16 = mixed_precision == 1
24
+
25
+ if self.enable_amp:
26
+ self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True)
27
+ self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None # only fp16 needs a scaler
28
+ else:
29
+ self.amp_ctx = NullCtx()
30
+ self.scaler = None
31
+
32
+ self.optimizer, self.names, self.paras = optimizer, names, paras # paras have been filtered so everyone requires grad
33
+ self.grad_clip = grad_clip
34
+ self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
35
+ self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')
36
+
37
+ self.r_accu = 1 / n_gradient_accumulation # r_accu == 1.0 / n_gradient_accumulation
38
+
39
+ def backward_clip_step(
40
+ self, stepping: bool, loss: torch.Tensor,
41
+ ) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:
42
+ # backward
43
+ loss = loss.mul(self.r_accu) # r_accu == 1.0 / n_gradient_accumulation
44
+ orig_norm = scaler_sc = None
45
+ if self.scaler is not None:
46
+ self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
47
+ else:
48
+ loss.backward(retain_graph=False, create_graph=False)
49
+
50
+ if stepping:
51
+ if self.scaler is not None: self.scaler.unscale_(self.optimizer)
52
+ if self.early_clipping:
53
+ orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)
54
+
55
+ if self.scaler is not None:
56
+ self.scaler.step(self.optimizer)
57
+ scaler_sc: float = self.scaler.get_scale()
58
+ if scaler_sc > 32768.: # fp16 overflows above 65536, so letting the scale grow past 32768 could be dangerous
59
+ self.scaler.update(new_scale=32768.)
60
+ else:
61
+ self.scaler.update()
62
+ try:
63
+ scaler_sc = float(math.log2(scaler_sc))
64
+ except Exception as e:
65
+ print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
66
+ raise e
67
+ else:
68
+ self.optimizer.step()
69
+
70
+ if self.late_clipping:
71
+ orig_norm = self.optimizer.global_grad_norm
72
+
73
+ self.optimizer.zero_grad(set_to_none=True)
74
+
75
+ return orig_norm, scaler_sc
76
+
77
+ def state_dict(self):
78
+ return {
79
+ 'optimizer': self.optimizer.state_dict()
80
+ } if self.scaler is None else {
81
+ 'scaler': self.scaler.state_dict(),
82
+ 'optimizer': self.optimizer.state_dict()
83
+ }
84
+
85
+ def load_state_dict(self, state, strict=True):
86
+ if self.scaler is not None:
87
+ try: self.scaler.load_state_dict(state['scaler'])
88
+ except Exception as e: print(f'[fp16 load_state_dict err] {e}')
89
+ self.optimizer.load_state_dict(state['optimizer'])
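A minimal sketch of AmpOptimizer on a toy module. mixed_precision=0 keeps it on the NullCtx / no-scaler path, so it also runs on CPU; the optimizer and hyperparameters below are illustrative.

import torch
from utils.amp_sc import AmpOptimizer

model = torch.nn.Linear(8, 8)
names = [n for n, p in model.named_parameters() if p.requires_grad]
paras = [p for _, p in model.named_parameters() if p.requires_grad]
opt = AmpOptimizer(
    mixed_precision=0,                       # 0: fp32, 1: fp16 + GradScaler, 2: bf16
    optimizer=torch.optim.AdamW(model.parameters(), lr=1e-3),
    names=names, paras=paras,
    grad_clip=2.0, n_gradient_accumulation=1,
)
with opt.amp_ctx:
    loss = model(torch.randn(4, 8)).pow(2).mean()
grad_norm, scale_log2 = opt.backward_clip_step(stepping=True, loss=loss)
print(float(grad_norm), scale_log2)          # scale_log2 stays None without the fp16 scaler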
utils/arg_util.py ADDED
@@ -0,0 +1,284 @@
1
+ import json
2
+ import os
3
+ import random
4
+ import re
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from collections import OrderedDict
9
+ from typing import Optional, Union
10
+
11
+ import numpy as np
12
+ import torch
13
+
14
+ try:
15
+ from tap import Tap
16
+ except ImportError as e:
17
+ print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
18
+ print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
19
+ time.sleep(5)
20
+ raise e
21
+
22
+ import dist
23
+
24
+
25
+ class Args(Tap):
26
+ data_path: str = '/path/to/imagenet'
27
+ exp_name: str = 'text'
28
+
29
+ # VAE
30
+ vfast: int = 0 # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
31
+ # VAR
32
+ tfast: int = 0 # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
33
+ depth: int = 16 # VAR depth
34
+ # VAR initialization
35
+ ini: float = -1 # -1: automated model parameter initialization
36
+ hd: float = 0.02 # head.w *= hd
37
+ aln: float = 0.5 # the multiplier of ada_lin.w's initialization
38
+ alng: float = 1e-5 # the multiplier of ada_lin.w[gamma channels]'s initialization
39
+ # VAR optimization
40
+ fp16: int = 0 # 1: using fp16, 2: bf16
41
+ tblr: float = 1e-4 # base lr
42
+ tlr: float = None # lr = base lr * (bs / 256)
43
+ twd: float = 0.05 # initial wd
44
+ twde: float = 0 # final wd, =twde or twd
45
+ tclip: float = 2. # <=0 for not using grad clip
46
+ ls: float = 0.0 # label smooth
47
+
48
+ bs: int = 768 # global batch size
49
+ batch_size: int = 0 # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
50
+ glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
51
+ ac: int = 1 # gradient accumulation
52
+
53
+ ep: int = 250
54
+ wp: float = 0
55
+ wp0: float = 0.005 # initial lr ratio at the beginning of lr warm-up
56
+ wpe: float = 0.01 # final lr ratio at the end of training
57
+ sche: str = 'lin0' # lr schedule
58
+
59
+ opt: str = 'adamw' # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
60
+ afuse: bool = True # fused adamw
61
+
62
+ # other hps
63
+ saln: bool = False # whether to use shared adaln
64
+ anorm: bool = True # whether to use L2 normalized attention
65
+ fuse: bool = True # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc.
66
+
67
+ # data
68
+ pn: str = '1_2_3_4_5_6_8_10_13_16'
69
+ patch_size: int = 16
70
+ patch_nums: tuple = None # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
71
+ resos: tuple = None # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)
72
+
73
+ data_load_reso: int = None # [automatically set; don't specify this] would be max(patch_nums) * patch_size
74
+ mid_reso: float = 1.125 # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
75
+ hflip: bool = False # augmentation: horizontal flip
76
+ workers: int = 0 # num workers; 0: auto, -1: don't use multiprocessing in DataLoader
77
+
78
+ # progressive training
79
+ pg: float = 0.0 # >0 for use progressive training during [0%, this] of training
80
+ pg0: int = 4 # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc
81
+ pgwp: float = 0 # num of warmup epochs at each progressive stage
82
+
83
+ # would be automatically set in runtime
84
+ cmd: str = ' '.join(sys.argv[1:]) # [automatically set; don't specify this]
85
+ branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
86
+ commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]' # [automatically set; don't specify this]
87
+ commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip() # [automatically set; don't specify this]
88
+ acc_mean: float = None # [automatically set; don't specify this]
89
+ acc_tail: float = None # [automatically set; don't specify this]
90
+ L_mean: float = None # [automatically set; don't specify this]
91
+ L_tail: float = None # [automatically set; don't specify this]
92
+ vacc_mean: float = None # [automatically set; don't specify this]
93
+ vacc_tail: float = None # [automatically set; don't specify this]
94
+ vL_mean: float = None # [automatically set; don't specify this]
95
+ vL_tail: float = None # [automatically set; don't specify this]
96
+ grad_norm: float = None # [automatically set; don't specify this]
97
+ cur_lr: float = None # [automatically set; don't specify this]
98
+ cur_wd: float = None # [automatically set; don't specify this]
99
+ cur_it: str = '' # [automatically set; don't specify this]
100
+ cur_ep: str = '' # [automatically set; don't specify this]
101
+ remain_time: str = '' # [automatically set; don't specify this]
102
+ finish_time: str = '' # [automatically set; don't specify this]
103
+
104
+ # environment
105
+ local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output') # [automatically set; don't specify this]
106
+ tb_log_dir_path: str = '...tb-...' # [automatically set; don't specify this]
107
+ log_txt_path: str = '...' # [automatically set; don't specify this]
108
+ last_ckpt_path: str = '...' # [automatically set; don't specify this]
109
+
110
+ tf32: bool = True # whether to use TensorFloat32
111
+ device: str = 'cpu' # [automatically set; don't specify this]
112
+ seed: int = None # seed
113
+ def seed_everything(self, benchmark: bool):
114
+ torch.backends.cudnn.enabled = True
115
+ torch.backends.cudnn.benchmark = benchmark
116
+ if self.seed is None:
117
+ torch.backends.cudnn.deterministic = False
118
+ else:
119
+ torch.backends.cudnn.deterministic = True
120
+ seed = self.seed * dist.get_world_size() + dist.get_rank()
121
+ os.environ['PYTHONHASHSEED'] = str(seed)
122
+ random.seed(seed)
123
+ np.random.seed(seed)
124
+ torch.manual_seed(seed)
125
+ if torch.cuda.is_available():
126
+ torch.cuda.manual_seed(seed)
127
+ torch.cuda.manual_seed_all(seed)
128
+ same_seed_for_all_ranks: int = 0 # this is only for distributed sampler
129
+ def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]: # for random augmentation
130
+ if self.seed is None:
131
+ return None
132
+ g = torch.Generator()
133
+ g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
134
+ return g
135
+
136
+ local_debug: bool = 'KEVIN_LOCAL' in os.environ
137
+ dbg_nan: bool = False # 'KEVIN_LOCAL' in os.environ
138
+
139
+ def compile_model(self, m, fast):
140
+ if fast == 0 or self.local_debug:
141
+ return m
142
+ return torch.compile(m, mode={
143
+ 1: 'reduce-overhead',
144
+ 2: 'max-autotune',
145
+ 3: 'default',
146
+ }[fast]) if hasattr(torch, 'compile') else m
147
+
148
+ def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
149
+ d = (OrderedDict if key_ordered else dict)()
150
+ # self.as_dict() would contain methods, but we only need variables
151
+ for k in self.class_variables.keys():
152
+ if k not in {'device'}: # these are not serializable
153
+ d[k] = getattr(self, k)
154
+ return d
155
+
156
+ def load_state_dict(self, d: Union[OrderedDict, dict, str]):
157
+ if isinstance(d, str): # for compatibility with old version
158
+ d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
159
+ for k in d.keys():
160
+ try:
161
+ setattr(self, k, d[k])
162
+ except Exception as e:
163
+ print(f'k={k}, v={d[k]}')
164
+ raise e
165
+
166
+ @staticmethod
167
+ def set_tf32(tf32: bool):
168
+ if torch.cuda.is_available():
169
+ torch.backends.cudnn.allow_tf32 = bool(tf32)
170
+ torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
171
+ if hasattr(torch, 'set_float32_matmul_precision'):
172
+ torch.set_float32_matmul_precision('high' if tf32 else 'highest')
173
+ print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
174
+ print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
175
+ print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')
176
+
177
+ def dump_log(self):
178
+ if not dist.is_local_master():
179
+ return
180
+ if '1/' in self.cur_ep: # first time to dump log
181
+ with open(self.log_txt_path, 'w') as fp:
182
+ json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
183
+ fp.write('\n')
184
+
185
+ log_dict = {}
186
+ for k, v in {
187
+ 'it': self.cur_it, 'ep': self.cur_ep,
188
+ 'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
189
+ 'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
190
+ 'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
191
+ 'remain_time': self.remain_time, 'finish_time': self.finish_time,
192
+ }.items():
193
+ if hasattr(v, 'item'): v = v.item()
194
+ log_dict[k] = v
195
+ with open(self.log_txt_path, 'a') as fp:
196
+ fp.write(f'{log_dict}\n')
197
+
198
+ def __str__(self):
199
+ s = []
200
+ for k in self.class_variables.keys():
201
+ if k not in {'device', 'dbg_ks_fp'}: # these are not serializable
202
+ s.append(f' {k:20s}: {getattr(self, k)}')
203
+ s = '\n'.join(s)
204
+ return f'{{\n{s}\n}}\n'
205
+
206
+
207
+ def init_dist_and_get_args():
208
+ for i in range(len(sys.argv)):
209
+ if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
210
+ del sys.argv[i]
211
+ break
212
+ args = Args(explicit_bool=True).parse_args(known_only=True)
213
+ if args.local_debug:
214
+ args.pn = '1_2_3'
215
+ args.seed = 1
216
+ args.aln = 1e-2
217
+ args.alng = 1e-5
218
+ args.saln = False
219
+ args.afuse = False
220
+ args.pg = 0.8
221
+ args.pg0 = 1
222
+ else:
223
+ if args.data_path == '/path/to/imagenet':
224
+ raise ValueError(f'{"*"*40} please specify --data_path=/path/to/imagenet {"*"*40}')
225
+
226
+ # warn args.extra_args
227
+ if len(args.extra_args) > 0:
228
+ print(f'======================================================================================')
229
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
230
+ print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
231
+ print(f'======================================================================================\n\n')
232
+
233
+ # init torch distributed
234
+ from utils import misc
235
+ os.makedirs(args.local_out_dir_path, exist_ok=True)
236
+ misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)
237
+
238
+ # set env
239
+ args.set_tf32(args.tf32)
240
+ args.seed_everything(benchmark=args.pg == 0)
241
+
242
+ # update args: data loading
243
+ args.device = dist.get_device()
244
+ if args.pn == '256':
245
+ args.pn = '1_2_3_4_5_6_8_10_13_16'
246
+ elif args.pn == '512':
247
+ args.pn = '1_2_3_4_6_9_13_18_24_32'
248
+ elif args.pn == '1024':
249
+ args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
250
+ args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
251
+ args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
252
+ args.data_load_reso = max(args.resos)
253
+
254
+ # update args: bs and lr
255
+ bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
256
+ args.batch_size = bs_per_gpu
257
+ args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
258
+ args.workers = min(max(0, args.workers), args.batch_size)
259
+
260
+ args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
261
+ args.twde = args.twde or args.twd
262
+
263
+ if args.wp == 0:
264
+ args.wp = args.ep * 1/50
265
+
266
+ # update args: progressive training
267
+ if args.pgwp == 0:
268
+ args.pgwp = args.ep * 1/300
269
+ if args.pg > 0:
270
+ args.sche = f'lin{args.pg:g}'
271
+
272
+ # update args: paths
273
+ args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
274
+ args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
275
+ _reg_valid_name = re.compile(r'[^\w\-+,.]')
276
+ tb_name = _reg_valid_name.sub(
277
+ '_',
278
+ f'tb-VARd{args.depth}'
279
+ f'__pn{args.pn}'
280
+ f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
281
+ )
282
+ args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)
283
+
284
+ return args
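A minimal sketch of parsing Args directly, bypassing init_dist_and_get_args and its distributed setup. It assumes typed-argument-parser is installed and that the repo's dist module is importable (arg_util imports it at module level); the paths and values passed are placeholders.

from utils.arg_util import Args

args = Args(explicit_bool=True).parse_args(
    ['--data_path', '/tmp/imagenet', '--depth', '16', '--bs', '32'], known_only=True,
)
# mirror the patch-number / resolution derivations done in init_dist_and_get_args
args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
print(args.depth, args.patch_nums[-1], args.resos[-1])   # 16 16 256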
utils/data.py ADDED
@@ -0,0 +1,54 @@
1
+ import os.path as osp
2
+
3
+ import PIL.Image as PImage
4
+ from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
5
+ from torchvision.transforms import InterpolationMode, transforms
6
+
7
+
8
+ def normalize_01_into_pm1(x): # normalize x from [0, 1] to [-1, 1] by (x*2) - 1
9
+ return x.add(x).add_(-1)
10
+
11
+
12
+ def build_dataset(
13
+ data_path: str, final_reso: int,
14
+ hflip=False, mid_reso=1.125,
15
+ ):
16
+ # build augmentations
17
+ mid_reso = round(mid_reso * final_reso) # first resize to mid_reso, then crop to final_reso
18
+ train_aug, val_aug = [
19
+ transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
20
+ transforms.RandomCrop((final_reso, final_reso)),
21
+ transforms.ToTensor(), normalize_01_into_pm1,
22
+ ], [
23
+ transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS), # transforms.Resize: resize the shorter edge to mid_reso
24
+ transforms.CenterCrop((final_reso, final_reso)),
25
+ transforms.ToTensor(), normalize_01_into_pm1,
26
+ ]
27
+ if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip())
28
+ train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug)
29
+
30
+ # build dataset
31
+ train_set = DatasetFolder(root=osp.join(data_path, 'train'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=train_aug)
32
+ val_set = DatasetFolder(root=osp.join(data_path, 'val'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=val_aug)
33
+ num_classes = 1000
34
+ print(f'[Dataset] {len(train_set)=}, {len(val_set)=}, {num_classes=}')
35
+ print_aug(train_aug, '[train]')
36
+ print_aug(val_aug, '[val]')
37
+
38
+ return num_classes, train_set, val_set
39
+
40
+
41
+ def pil_loader(path):
42
+ with open(path, 'rb') as f:
43
+ img: PImage.Image = PImage.open(f).convert('RGB')
44
+ return img
45
+
46
+
47
+ def print_aug(transform, label):
48
+ print(f'Transform {label} = ')
49
+ if hasattr(transform, 'transforms'):
50
+ for t in transform.transforms:
51
+ print(t)
52
+ else:
53
+ print(transform)
54
+ print('---------------------------\n')
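A minimal sketch of the validation-side augmentation above applied to a single synthetic image, so the pipeline can be checked without the ImageNet folder layout that build_dataset expects.

import PIL.Image as PImage
from torchvision.transforms import InterpolationMode, transforms
from utils.data import normalize_01_into_pm1

final_reso = 256
mid_reso = round(1.125 * final_reso)      # resize the shorter edge to 288, then center-crop to 256
val_aug = transforms.Compose([
    transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS),
    transforms.CenterCrop((final_reso, final_reso)),
    transforms.ToTensor(), normalize_01_into_pm1,
])
img = PImage.new('RGB', (512, 384), color=(128, 64, 255))    # stand-in image
x = val_aug(img)
print(x.shape, float(x.min()), float(x.max()))               # torch.Size([3, 256, 256]), values in [-1, 1]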
utils/data_sampler.py ADDED
@@ -0,0 +1,103 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data.sampler import Sampler
4
+
5
+
6
+ class EvalDistributedSampler(Sampler):
7
+ def __init__(self, dataset, num_replicas, rank):
8
+ seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int)
9
+ beg, end = seps[:-1], seps[1:]
10
+ beg, end = beg[rank], end[rank]
11
+ self.indices = tuple(range(beg, end))
12
+
13
+ def __iter__(self):
14
+ return iter(self.indices)
15
+
16
+ def __len__(self) -> int:
17
+ return len(self.indices)
18
+
19
+
20
+ class InfiniteBatchSampler(Sampler):
21
+ def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0):
22
+ self.dataset_len = dataset_len
23
+ self.batch_size = batch_size
24
+ self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
25
+ self.max_p = self.iters_per_ep * batch_size
26
+ self.fill_last = fill_last
27
+ self.shuffle = shuffle
28
+ self.epoch = start_ep
29
+ self.same_seed_for_all_ranks = seed_for_all_rank
30
+ self.indices = self.gener_indices()
31
+ self.start_ep, self.start_it = start_ep, start_it
32
+
33
+ def gener_indices(self):
34
+ if self.shuffle:
35
+ g = torch.Generator()
36
+ g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
37
+ indices = torch.randperm(self.dataset_len, generator=g).numpy()
38
+ else:
39
+ indices = torch.arange(self.dataset_len).numpy()
40
+
41
+ tails = self.batch_size - (self.dataset_len % self.batch_size)
42
+ if tails != self.batch_size and self.fill_last:
43
+ tails = indices[:tails]
44
+ np.random.shuffle(indices)
45
+ indices = np.concatenate((indices, tails))
46
+
47
+ # built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
48
+ # noinspection PyTypeChecker
49
+ return tuple(indices.tolist())
50
+
51
+ def __iter__(self):
52
+ self.epoch = self.start_ep
53
+ while True:
54
+ self.epoch += 1
55
+ p = (self.start_it * self.batch_size) if self.epoch == self.start_ep else 0
56
+ while p < self.max_p:
57
+ q = p + self.batch_size
58
+ yield self.indices[p:q]
59
+ p = q
60
+ if self.shuffle:
61
+ self.indices = self.gener_indices()
62
+
63
+ def __len__(self):
64
+ return self.iters_per_ep
65
+
66
+
67
+ class DistInfiniteBatchSampler(InfiniteBatchSampler):
68
+ def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0):
69
+ assert glb_batch_size % world_size == 0
70
+ self.world_size, self.rank = world_size, rank
71
+ self.dataset_len = dataset_len
72
+ self.glb_batch_size = glb_batch_size
73
+ self.batch_size = glb_batch_size // world_size
74
+
75
+ self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
76
+ self.fill_last = fill_last
77
+ self.shuffle = shuffle
78
+ self.repeated_aug = repeated_aug
79
+ self.epoch = start_ep
80
+ self.same_seed_for_all_ranks = same_seed_for_all_ranks
81
+ self.indices = self.gener_indices()
82
+ self.start_ep, self.start_it = start_ep, start_it
83
+
84
+ def gener_indices(self):
85
+ global_max_p = self.iters_per_ep * self.glb_batch_size # global_max_p % world_size must be 0 cuz glb_batch_size % world_size == 0
86
+ # print(f'global_max_p = iters_per_ep({self.iters_per_ep}) * glb_batch_size({self.glb_batch_size}) = {global_max_p}')
87
+ if self.shuffle:
88
+ g = torch.Generator()
89
+ g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
90
+ global_indices = torch.randperm(self.dataset_len, generator=g)
91
+ if self.repeated_aug > 1:
92
+ global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
93
+ else:
94
+ global_indices = torch.arange(self.dataset_len)
95
+ filling = global_max_p - global_indices.shape[0]
96
+ if filling > 0 and self.fill_last:
97
+ global_indices = torch.cat((global_indices, global_indices[:filling]))
98
+ # global_indices = tuple(global_indices.numpy().tolist())
99
+
100
+ seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int)
101
+ local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist()
102
+ self.max_p = len(local_indices)
103
+ return local_indices
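A minimal sketch of DistInfiniteBatchSampler in a single-process setting (world_size=1, rank=0). The sampler yields index batches indefinitely, so the loop below only takes the first few; the dataset length and batch size are illustrative.

from itertools import islice
from utils.data_sampler import DistInfiniteBatchSampler

sampler = DistInfiniteBatchSampler(
    world_size=1, rank=0, dataset_len=10, glb_batch_size=4,
    fill_last=True, shuffle=True, start_ep=0, start_it=0,
)
print(len(sampler))                       # iters per epoch: ceil(10 / 4) = 3
for batch in islice(iter(sampler), 4):    # the fourth batch already spills into the next epoch
    print(batch)                          # each batch is a list of 4 dataset indices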
utils/lr_control.py ADDED
@@ -0,0 +1,108 @@
1
+ import math
2
+ from pprint import pformat
3
+ from typing import Tuple, List, Dict, Union
4
+
5
+ import torch.nn
6
+
7
+ import dist
8
+
9
+
10
+ def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
11
+ """Decay the learning rate with half-cycle cosine after warmup"""
12
+ wp_it = round(wp_it)
13
+
14
+ if cur_it < wp_it:
15
+ cur_lr = wp0 + (1-wp0) * cur_it / wp_it
16
+ else:
17
+ pasd = (cur_it - wp_it) / (max_it-1 - wp_it) # [0, 1]
18
+ rest = 1 - pasd # [1, 0]
19
+ if sche_type == 'cos':
20
+ cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
21
+ elif sche_type == 'lin':
22
+ T = 0.15; max_rest = 1-T
23
+ if pasd < T: cur_lr = 1
24
+ else: cur_lr = wpe + (1-wpe) * rest / max_rest # 1 to wpe
25
+ elif sche_type == 'lin0':
26
+ T = 0.05; max_rest = 1-T
27
+ if pasd < T: cur_lr = 1
28
+ else: cur_lr = wpe + (1-wpe) * rest / max_rest
29
+ elif sche_type == 'lin00':
30
+ cur_lr = wpe + (1-wpe) * rest
31
+ elif sche_type.startswith('lin'):
32
+ T = float(sche_type[3:]); max_rest = 1-T
33
+ wpe_mid = wpe + (1-wpe) * max_rest
34
+ wpe_mid = (1 + wpe_mid) / 2
35
+ if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
36
+ else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
37
+ elif sche_type == 'exp':
38
+ T = 0.15; max_rest = 1-T
39
+ if pasd < T: cur_lr = 1
40
+ else:
41
+ expo = (pasd-T) / max_rest * math.log(wpe)
42
+ cur_lr = math.exp(expo)
43
+ else:
44
+ raise NotImplementedError(f'unknown sche_type {sche_type}')
45
+
46
+ cur_lr *= peak_lr
47
+ pasd = cur_it / (max_it-1)
48
+ cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
49
+
50
+ inf = 1e6
51
+ min_lr, max_lr = inf, -1
52
+ min_wd, max_wd = inf, -1
53
+ for param_group in optimizer.param_groups:
54
+ param_group['lr'] = cur_lr * param_group.get('lr_sc', 1) # 'lr_sc' could be assigned
55
+ max_lr = max(max_lr, param_group['lr'])
56
+ min_lr = min(min_lr, param_group['lr'])
57
+
58
+ param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
59
+ max_wd = max(max_wd, param_group['weight_decay'])
60
+ if param_group['weight_decay'] > 0:
61
+ min_wd = min(min_wd, param_group['weight_decay'])
62
+
63
+ if min_lr == inf: min_lr = -1
64
+ if min_wd == inf: min_wd = -1
65
+ return min_lr, max_lr, min_wd, max_wd
66
+
67
+
68
+ def filter_params(model, nowd_keys=()) -> Tuple[
69
+ List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
70
+ ]:
71
+ para_groups, para_groups_dbg = {}, {}
72
+ names, paras = [], []
73
+ names_no_grad = []
74
+ count, numel = 0, 0
75
+ for name, para in model.named_parameters():
76
+ name = name.replace('_fsdp_wrapped_module.', '')
77
+ if not para.requires_grad:
78
+ names_no_grad.append(name)
79
+ continue # frozen weights
80
+ count += 1
81
+ numel += para.numel()
82
+ names.append(name)
83
+ paras.append(para)
84
+
85
+ if para.ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
86
+ cur_wd_sc, group_name = 0., 'ND'
87
+ else:
88
+ cur_wd_sc, group_name = 1., 'D'
89
+ cur_lr_sc = 1.
90
+ if group_name not in para_groups:
91
+ para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
92
+ para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
93
+ para_groups[group_name]['params'].append(para)
94
+ para_groups_dbg[group_name]['params'].append(name)
95
+
96
+ for g in para_groups_dbg.values():
97
+ g['params'] = pformat(', '.join(g['params']), width=200)
98
+
99
+ print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
100
+
101
+ for rk in range(dist.get_world_size()):
102
+ dist.barrier()
103
+ if dist.get_rank() == rk:
104
+ print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
105
+ print('')
106
+
107
+ assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
108
+ return names, paras, list(para_groups.values())
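A minimal sketch of lr_wd_annealing above on a toy optimizer: linear warm-up for the first wp_it iterations, then the default 'lin0' schedule with cosine weight-decay annealing. It assumes the repo's dist module is importable (lr_control imports it at module level); the specific numbers are illustrative.

import torch
from utils.lr_control import lr_wd_annealing

model = torch.nn.Linear(4, 4)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
max_it, wp_it = 1000, 50
for it in (0, 25, 50, 500, 999):
    min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(
        'lin0', opt, peak_lr=1e-4, wd=0.05, wd_end=0.0,
        cur_it=it, wp_it=wp_it, max_it=max_it, wp0=0.005, wpe=0.01,
    )
    # the schedule writes lr / weight_decay into opt.param_groups and returns their extremes
    print(f'it={it:4d}  lr={max_lr:.2e}  wd={max_wd:.3f}')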
utils/misc.py ADDED
@@ -0,0 +1,381 @@
1
+ import datetime
2
+ import functools
3
+ import glob
4
+ import os
5
+ import subprocess
6
+ import sys
7
+ import time
8
+ from collections import defaultdict, deque
9
+ from typing import Iterator, List, Tuple
10
+
11
+ import numpy as np
12
+ import pytz
13
+ import torch
14
+ import torch.distributed as tdist
15
+
16
+ import dist
17
+ from utils import arg_util
18
+
19
+ os_system = functools.partial(subprocess.call, shell=True)
20
+ def echo(info):
21
+ os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
22
+ def os_system_get_stdout(cmd):
23
+ return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
24
+ def os_system_get_stdout_stderr(cmd):
25
+ cnt = 0
26
+ while True:
27
+ try:
28
+ sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
29
+ except subprocess.TimeoutExpired:
30
+ cnt += 1
31
+ print(f'[fetch free_port file] timeout cnt={cnt}')
32
+ else:
33
+ return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')
34
+
35
+
36
+ def time_str(fmt='[%m-%d %H:%M:%S]'):
37
+ return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)
38
+
39
+
40
+ def init_distributed_mode(local_out_path, only_sync_master=False, timeout=30):
41
+ try:
42
+ dist.initialize(fork=False, timeout=timeout)
43
+ dist.barrier()
44
+ except RuntimeError:
45
+ print(f'{">"*75} NCCL Error {"<"*75}', flush=True)
46
+ time.sleep(10)
47
+
48
+ if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
49
+ _change_builtin_print(dist.is_local_master())
50
+ if (dist.is_master() if only_sync_master else dist.is_local_master()) and local_out_path is not None and len(local_out_path):
51
+ sys.stdout, sys.stderr = SyncPrint(local_out_path, sync_stdout=True), SyncPrint(local_out_path, sync_stdout=False)
52
+
53
+
54
+ def _change_builtin_print(is_master):
55
+ import builtins as __builtin__
56
+
57
+ builtin_print = __builtin__.print
58
+ if type(builtin_print) != type(open):
59
+ return
60
+
61
+ def prt(*args, **kwargs):
62
+ force = kwargs.pop('force', False)
63
+ clean = kwargs.pop('clean', False)
64
+ deeper = kwargs.pop('deeper', False)
65
+ if is_master or force:
66
+ if not clean:
67
+ f_back = sys._getframe().f_back
68
+ if deeper and f_back.f_back is not None:
69
+ f_back = f_back.f_back
70
+ file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
71
+ builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
72
+ else:
73
+ builtin_print(*args, **kwargs)
74
+
75
+ __builtin__.print = prt
76
+
77
+
78
+ class SyncPrint(object):
79
+ def __init__(self, local_output_dir, sync_stdout=True):
80
+ self.sync_stdout = sync_stdout
81
+ self.terminal_stream = sys.stdout if sync_stdout else sys.stderr
82
+ fname = os.path.join(local_output_dir, 'stdout.txt' if sync_stdout else 'stderr.txt')
83
+ existing = os.path.exists(fname)
84
+ self.file_stream = open(fname, 'a')
85
+ if existing:
86
+ self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str()} ' + '='*55 + '\n')
87
+ self.file_stream.flush()
88
+ self.enabled = True
89
+
90
+ def write(self, message):
91
+ self.terminal_stream.write(message)
92
+ self.file_stream.write(message)
93
+
94
+ def flush(self):
95
+ self.terminal_stream.flush()
96
+ self.file_stream.flush()
97
+
98
+ def close(self):
99
+ if not self.enabled:
100
+ return
101
+ self.enabled = False
102
+ self.file_stream.flush()
103
+ self.file_stream.close()
104
+ if self.sync_stdout:
105
+ sys.stdout = self.terminal_stream
106
+ sys.stdout.flush()
107
+ else:
108
+ sys.stderr = self.terminal_stream
109
+ sys.stderr.flush()
110
+
111
+ def __del__(self):
112
+ self.close()
113
+
114
+
115
+ class DistLogger(object):
116
+ def __init__(self, lg, verbose):
117
+ self._lg, self._verbose = lg, verbose
118
+
119
+ @staticmethod
120
+ def do_nothing(*args, **kwargs):
121
+ pass
122
+
123
+ def __getattr__(self, attr: str):
124
+ return getattr(self._lg, attr) if self._verbose else DistLogger.do_nothing
125
+
126
+
127
+ class TensorboardLogger(object):
128
+ def __init__(self, log_dir, filename_suffix):
129
+ try: import tensorflow_io as tfio
130
+ except: pass
131
+ from torch.utils.tensorboard import SummaryWriter
132
+ self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
133
+ self.step = 0
134
+
135
+ def set_step(self, step=None):
136
+ if step is not None:
137
+ self.step = step
138
+ else:
139
+ self.step += 1
140
+
141
+ def update(self, head='scalar', step=None, **kwargs):
142
+ for k, v in kwargs.items():
143
+ if v is None:
144
+ continue
145
+ # assert isinstance(v, (float, int)), type(v)
146
+ if step is None: # iter wise
147
+ it = self.step
148
+ if it == 0 or (it + 1) % 500 == 0:
149
+ if hasattr(v, 'item'): v = v.item()
150
+ self.writer.add_scalar(f'{head}/{k}', v, it)
151
+ else: # epoch wise
152
+ if hasattr(v, 'item'): v = v.item()
153
+ self.writer.add_scalar(f'{head}/{k}', v, step)
154
+
155
+ def log_tensor_as_distri(self, tag, tensor1d, step=None):
156
+ if step is None: # iter wise
157
+ step = self.step
158
+ loggable = step == 0 or (step + 1) % 500 == 0
159
+ else: # epoch wise
160
+ loggable = True
161
+ if loggable:
162
+ try:
163
+ self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
164
+ except Exception as e:
165
+ print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
166
+
167
+ def log_image(self, tag, img_chw, step=None):
168
+ if step is None: # iter wise
169
+ step = self.step
170
+ loggable = step == 0 or (step + 1) % 500 == 0
171
+ else: # epoch wise
172
+ loggable = True
173
+ if loggable:
174
+ self.writer.add_image(tag, img_chw, step, dataformats='CHW')
175
+
176
+ def flush(self):
177
+ self.writer.flush()
178
+
179
+ def close(self):
180
+ self.writer.close()
181
+
182
+
183
+ class SmoothedValue(object):
184
+ """Track a series of values and provide access to smoothed values over a
185
+ window or the global series average.
186
+ """
187
+
188
+ def __init__(self, window_size=30, fmt=None):
189
+ if fmt is None:
190
+ fmt = "{median:.4f} ({global_avg:.4f})"
191
+ self.deque = deque(maxlen=window_size)
192
+ self.total = 0.0
193
+ self.count = 0
194
+ self.fmt = fmt
195
+
196
+ def update(self, value, n=1):
197
+ self.deque.append(value)
198
+ self.count += n
199
+ self.total += value * n
200
+
201
+ def synchronize_between_processes(self):
202
+ """
203
+ Warning: does not synchronize the deque!
204
+ """
205
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
206
+ tdist.barrier()
207
+ tdist.all_reduce(t)
208
+ t = t.tolist()
209
+ self.count = int(t[0])
210
+ self.total = t[1]
211
+
212
+ @property
213
+ def median(self):
214
+ return np.median(self.deque) if len(self.deque) else 0
215
+
216
+ @property
217
+ def avg(self):
218
+ return sum(self.deque) / (len(self.deque) or 1)
219
+
220
+ @property
221
+ def global_avg(self):
222
+ return self.total / (self.count or 1)
223
+
224
+ @property
225
+ def max(self):
226
+ return max(self.deque)
227
+
228
+ @property
229
+ def value(self):
230
+ return self.deque[-1] if len(self.deque) else 0
231
+
232
+ def time_preds(self, counts) -> Tuple[float, str, str]:
233
+ remain_secs = counts * self.median
234
+ return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
235
+
236
+ def __str__(self):
237
+ return self.fmt.format(
238
+ median=self.median,
239
+ avg=self.avg,
240
+ global_avg=self.global_avg,
241
+ max=self.max,
242
+ value=self.value)
243
+
244
+
245
+ class MetricLogger(object):
246
+ def __init__(self, delimiter=' '):
247
+ self.meters = defaultdict(SmoothedValue)
248
+ self.delimiter = delimiter
249
+ self.iter_end_t = time.time()
250
+ self.log_iters = []
251
+
252
+ def update(self, **kwargs):
253
+ for k, v in kwargs.items():
254
+ if v is None:
255
+ continue
256
+ if hasattr(v, 'item'): v = v.item()
257
+             assert isinstance(v, (float, int)), type(v)
259
+ self.meters[k].update(v)
260
+
261
+ def __getattr__(self, attr):
262
+ if attr in self.meters:
263
+ return self.meters[attr]
264
+ if attr in self.__dict__:
265
+ return self.__dict__[attr]
266
+ raise AttributeError("'{}' object has no attribute '{}'".format(
267
+ type(self).__name__, attr))
268
+
269
+ def __str__(self):
270
+ loss_str = []
271
+ for name, meter in self.meters.items():
272
+ if len(meter.deque):
273
+ loss_str.append(
274
+ "{}: {}".format(name, str(meter))
275
+ )
276
+ return self.delimiter.join(loss_str)
277
+
278
+ def synchronize_between_processes(self):
279
+ for meter in self.meters.values():
280
+ meter.synchronize_between_processes()
281
+
282
+ def add_meter(self, name, meter):
283
+ self.meters[name] = meter
284
+
285
+ def log_every(self, start_it, max_iters, itrt, print_freq, header=None):
286
+ self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())
287
+ self.log_iters.add(start_it)
288
+ if not header:
289
+ header = ''
290
+ start_time = time.time()
291
+ self.iter_end_t = time.time()
292
+ self.iter_time = SmoothedValue(fmt='{avg:.4f}')
293
+ self.data_time = SmoothedValue(fmt='{avg:.4f}')
294
+ space_fmt = ':' + str(len(str(max_iters))) + 'd'
295
+ log_msg = [
296
+ header,
297
+ '[{0' + space_fmt + '}/{1}]',
298
+ 'eta: {eta}',
299
+ '{meters}',
300
+ 'time: {time}',
301
+ 'data: {data}'
302
+ ]
303
+ log_msg = self.delimiter.join(log_msg)
304
+
305
+ if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
306
+ for i in range(start_it, max_iters):
307
+ obj = next(itrt)
308
+ self.data_time.update(time.time() - self.iter_end_t)
309
+ yield i, obj
310
+ self.iter_time.update(time.time() - self.iter_end_t)
311
+ if i in self.log_iters:
312
+ eta_seconds = self.iter_time.global_avg * (max_iters - i)
313
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
314
+ print(log_msg.format(
315
+ i, max_iters, eta=eta_string,
316
+ meters=str(self),
317
+ time=str(self.iter_time), data=str(self.data_time)), flush=True)
318
+ self.iter_end_t = time.time()
319
+ else:
320
+ if isinstance(itrt, int): itrt = range(itrt)
321
+ for i, obj in enumerate(itrt):
322
+ self.data_time.update(time.time() - self.iter_end_t)
323
+ yield i, obj
324
+ self.iter_time.update(time.time() - self.iter_end_t)
325
+ if i in self.log_iters:
326
+ eta_seconds = self.iter_time.global_avg * (max_iters - i)
327
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
328
+ print(log_msg.format(
329
+ i, max_iters, eta=eta_string,
330
+ meters=str(self),
331
+ time=str(self.iter_time), data=str(self.data_time)), flush=True)
332
+ self.iter_end_t = time.time()
333
+
334
+ total_time = time.time() - start_time
335
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
336
+ print('{} Total time: {} ({:.3f} s / it)'.format(
337
+ header, total_time_str, total_time / max_iters), flush=True)
338
+
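+ # Note on MetricLogger.log_every: `print_freq` is the number of log points, not a
+ # per-iteration frequency; np.linspace spreads them evenly over the run, and
+ # `add(start_it)` guarantees the resumed start iteration is always printed, e.g.
+ #   np.linspace(0, 99, 5, dtype=int).tolist() -> [0, 24, 49, 74, 99]
+ # so a 100-iteration loop with print_freq=5 logs at those five iterations.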
339
+
340
+ def glob_with_latest_modified_first(pattern, recursive=False):
341
+ return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True)
342
+
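+ # glob_with_latest_modified_first sorts matches by modification time, newest first,
+ # so auto_resume below can simply take index 0 as the most recent checkpoint.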
343
+
344
+ def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]:
345
+ info = []
346
+ file = os.path.join(args.local_out_dir_path, pattern)
347
+ all_ckpt = glob_with_latest_modified_first(file)
348
+ if len(all_ckpt) == 0:
349
+ info.append(f'[auto_resume] no ckpt found @ {file}')
350
+ info.append(f'[auto_resume quit]')
351
+ return info, 0, 0, {}, {}
352
+ else:
353
+ info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...')
354
+ ckpt = torch.load(all_ckpt[0], map_location='cpu')
355
+ ep, it = ckpt['epoch'], ckpt['iter']
356
+ info.append(f'[auto_resume success] resume from ep{ep}, it{it}')
357
+ return info, ep, it, ckpt['trainer'], ckpt['args']
358
+
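+ # auto_resume returns (log_lines, start_epoch, start_iter, trainer_state, args_state);
+ # when no checkpoint matches the pattern it falls back to (info, 0, 0, {}, {}) so the
+ # caller can start training from scratch without special-casing the first run.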
359
+
360
+ def create_npz_from_sample_folder(sample_folder: str):
361
+ """
362
+     Builds a single .npz file from a folder of .png samples, following the evaluation format used by DiT.
363
+ """
364
+ import os, glob
365
+ import numpy as np
366
+ from tqdm import tqdm
367
+ from PIL import Image
368
+
369
+ samples = []
370
+ pngs = glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG'))
371
+ assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000'
372
+ for png in tqdm(pngs, desc='Building .npz file from samples (png only)'):
373
+ with Image.open(png) as sample_pil:
374
+ sample_np = np.asarray(sample_pil).astype(np.uint8)
375
+ samples.append(sample_np)
376
+ samples = np.stack(samples)
377
+ assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3)
378
+ npz_path = f'{sample_folder}.npz'
379
+ np.savez(npz_path, arr_0=samples)
380
+ print(f'Saved .npz file to {npz_path} [shape={samples.shape}].')
381
+ return npz_path
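+ # A minimal sketch of the intended use (the folder path is a placeholder): write exactly
+ # 50,000 .png samples into one directory, then pack them for FID-style evaluation:
+ #   npz_path = create_npz_from_sample_folder('samples/run1')
+ #   # -> 'samples/run1.npz' containing key 'arr_0', shape (50000, H, W, 3), dtype uint8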