Commit 64bf706 (parent: 3aaab28) by AmitIsraeli

    Add model and inference app
Files changed:

- VARtext_v1.pth (+3, -0)
- app.py (+236, -4)
- dist.py (+211, -0)
- models/__init__.py (+39, -0)
- models/basic_vae.py (+226, -0)
- models/basic_var.py (+174, -0)
- models/helpers.py (+59, -0)
- models/quant.py (+281, -0)
- models/var.py (+360, -0)
- models/vqvae.py (+95, -0)
- utils/amp_sc.py (+89, -0)
- utils/arg_util.py (+284, -0)
- utils/data.py (+54, -0)
- utils/data_sampler.py (+103, -0)
- utils/lr_control.py (+108, -0)
- utils/misc.py (+381, -0)
VARtext_v1.pth (ADDED, Git LFS pointer) @@ -0,0 +1,3 @@

```
version https://git-lfs.github.com/spec/v1
oid sha256:bbaa03cee25cb0abba7ac5d476f6b800b78dda29c6cb2773a11b584022585fcf
size 1963751390
```
app.py (CHANGED) @@ -1,7 +1,239 @@

```python
import torch
from models import VQVAE, build_vae_var
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, SiglipTextModel
from peft import LoraConfig, get_peft_model
import random
from torchvision.transforms import ToPILImage
import numpy as np
from moviepy.editor import ImageSequenceClip
import gradio as gr
import tempfile
import os


class SimpleAdapter(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=1024, out_dim=1024):
        super(SimpleAdapter, self).__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.norm0 = nn.LayerNorm(input_dim)
        self.activation1 = nn.GELU()
        self.layer2 = nn.Linear(hidden_dim, out_dim)
        self.norm2 = nn.LayerNorm(out_dim)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight, gain=0.001)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LayerNorm):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        x = self.norm0(x)
        x = self.layer1(x)
        x = self.activation1(x)
        x = self.layer2(x)
        x = self.norm2(x)
        return x


class InrenceTextVAR(nn.Module):
    def __init__(self, pl_checkpoint=None, start_class_id=578, hugging_face_token=None, siglip_model='google/siglip-base-patch16-224', device="cpu", MODEL_DEPTH=16):
        super(InrenceTextVAR, self).__init__()
        self.device = device
        self.class_id = start_class_id
        # Define layers
        patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
        self.vae, self.var = build_vae_var(
            V=4096, Cvae=32, ch=160, share_quant_resi=4,
            device=device, patch_nums=patch_nums,
            num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
        )
        self.text_processor = AutoTokenizer.from_pretrained(siglip_model, token=hugging_face_token)
        self.siglip_text_encoder = SiglipTextModel.from_pretrained(siglip_model, token=hugging_face_token).to(device)
        self.adapter = SimpleAdapter(
            input_dim=self.siglip_text_encoder.config.hidden_size,
            out_dim=self.var.C  # Ensure dimensional consistency
        ).to(device)
        self.apply_lora_to_var()
        if pl_checkpoint is not None:
            state_dict = torch.load(pl_checkpoint, map_location="cpu")['state_dict']
            var_state_dict = {k[len('var.'):]: v for k, v in state_dict.items() if k.startswith('var.')}
            vae_state_dict = {k[len('vae.'):]: v for k, v in state_dict.items() if k.startswith('vae.')}
            adapter_state_dict = {k[len('adapter.'):]: v for k, v in state_dict.items() if k.startswith('adapter.')}
            self.var.load_state_dict(var_state_dict)
            self.vae.load_state_dict(vae_state_dict)
            self.adapter.load_state_dict(adapter_state_dict)
        del self.vae.encoder  # inference only needs the decoder

    def apply_lora_to_var(self):
        """
        Applies LoRA (Low-Rank Adaptation) to the VAR model.
        """
        def find_linear_module_names(model):
            linear_module_names = []
            for name, module in model.named_modules():
                if isinstance(module, nn.Linear):
                    linear_module_names.append(name)
            return linear_module_names

        linear_module_names = find_linear_module_names(self.var)

        lora_config = LoraConfig(
            r=8,
            lora_alpha=32,
            target_modules=linear_module_names,
            lora_dropout=0.05,
            bias="none",
        )

        self.var = get_peft_model(self.var, lora_config)

    @torch.no_grad()
    def generate_image(self, text, beta=1, seed=None, more_smooth=False, top_k=0, top_p=0.9):
        if seed is None:
            seed = random.randint(0, 2**32 - 1)
        inputs = self.text_processor([text], padding="max_length", return_tensors="pt").to(self.device)
        outputs = self.siglip_text_encoder(**inputs)
        pooled_output = outputs.pooler_output  # pooled (EOS token) states
        cond_delta = F.normalize(pooled_output, p=2, dim=-1).to(self.device)  # normalize the delta condition
        cond_delta = self.adapter(cond_delta)
        cond_delta = F.normalize(cond_delta, p=2, dim=-1)  # re-normalize after the adapter
        generated_images = self.var.autoregressive_infer_cfg(
            B=1,
            label_B=self.class_id,
            delta_condition=cond_delta[:1],
            beta=beta,
            alpha=1,
            top_k=top_k,
            top_p=top_p,
            more_smooth=more_smooth,
            g_seed=seed
        )
        image = ToPILImage()(generated_images[0].cpu())
        return image

    @torch.no_grad()
    def generate_video(self, text, start_beta, target_beta, fps, length, top_k=0, top_p=0.9, seed=None,
                       more_smooth=False,
                       output_filename='output_video.mp4'):
        if seed is None:
            seed = random.randint(0, 2 ** 32 - 1)

        num_frames = int(fps * length)
        images = []

        # Define an easing function for smoother interpolation
        def ease_in_out(t):
            return t * t * (3 - 2 * t)

        # Generate t values between 0 and 1
        t_values = np.linspace(0, 1, num_frames)
        # Apply the easing function
        eased_t_values = ease_in_out(t_values)
        # Interpolate beta values using the eased t values
        beta_values = start_beta + (target_beta - start_beta) * eased_t_values

        for beta in beta_values:
            image = self.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=top_k, top_p=top_p)
            images.append(np.array(image))

        # Create a video from images
        clip = ImageSequenceClip(images, fps=fps)
        clip.write_videofile(output_filename, codec='libx264')


if __name__ == '__main__':
    # Initialize the model
    checkpoint = 'VARtext_v1.pth'  # Replace with your actual checkpoint path
    device = 'cpu' if not torch.cuda.is_available() else 'cuda'
    state_dict = torch.load(checkpoint, map_location="cpu")
    model = InrenceTextVAR(device=device)
    model.load_state_dict(state_dict)
    model.to(device)

    def generate_image_gradio(text, beta=1.0, seed=None, more_smooth=False, top_k=0, top_p=0.9):
        print(f"Generating image for text: {text}\n"
              f"beta: {beta}\n"
              f"seed: {seed}\n"
              f"more_smooth: {more_smooth}\n"
              f"top_k: {top_k}\n"
              f"top_p: {top_p}\n")
        image = model.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=int(top_k), top_p=top_p)
        return image

    def generate_video_gradio(text, start_beta=1.0, target_beta=1.0, fps=10, length=5.0, top_k=0, top_p=0.9, seed=None, more_smooth=False, progress=gr.Progress()):
        print(f"Generating video for text: {text}\n"
              f"start_beta: {start_beta}\n"
              f"target_beta: {target_beta}\n"
              f"seed: {seed}\n"
              f"more_smooth: {more_smooth}\n"
              f"top_k: {top_k}\n"
              f"top_p: {top_p}\n"
              f"fps: {fps}\n"
              f"length: {length}\n")
        with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as tmpfile:
            output_filename = tmpfile.name
        num_frames = int(fps * length)
        beta_values = np.linspace(start_beta, target_beta, num_frames)
        images = []

        for i, beta in enumerate(beta_values):
            image = model.generate_image(text, beta=beta, seed=seed, more_smooth=more_smooth, top_k=int(top_k), top_p=top_p)
            images.append(np.array(image))
            # Update progress
            progress((i + 1) / num_frames)
            # Yield the frame image to update the GUI
            yield image, gr.update()

        # After generating all frames, create the video
        clip = ImageSequenceClip(images, fps=fps)
        clip.write_videofile(output_filename, codec='libx264')

        # Yield the final video output
        yield gr.update(), output_filename

    with gr.Blocks() as demo:
        gr.Markdown("# Text to Image/Video Generator")
        with gr.Tab("Generate Image"):
            text_input = gr.Textbox(label="Input Text")
            beta_input = gr.Slider(label="Beta", minimum=0.0, maximum=2.5, step=0.05, value=1.0)
            seed_input = gr.Number(label="Seed", value=None)
            more_smooth_input = gr.Checkbox(label="More Smooth", value=False)
            top_k_input = gr.Number(label="Top K", value=0)
            top_p_input = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
            generate_button = gr.Button("Generate Image")
            image_output = gr.Image(label="Generated Image")
            generate_button.click(
                generate_image_gradio,
                inputs=[text_input, beta_input, seed_input, more_smooth_input, top_k_input, top_p_input],
                outputs=image_output
            )

        with gr.Tab("Generate Video"):
            text_input_video = gr.Textbox(label="Input Text")
            start_beta_input = gr.Slider(label="Start Beta", minimum=0.0, maximum=2.5, step=0.05, value=0.0)
            target_beta_input = gr.Slider(label="Target Beta", minimum=0.0, maximum=2.5, step=0.05, value=1.0)
            fps_input = gr.Number(label="FPS", value=10)
            length_input = gr.Number(label="Length (seconds)", value=5.0)
            seed_input_video = gr.Number(label="Seed", value=None)
            more_smooth_input_video = gr.Checkbox(label="More Smooth", value=False)
            top_k_input_video = gr.Number(label="Top K", value=0)
            top_p_input_video = gr.Slider(label="Top P", minimum=0.0, maximum=1.0, step=0.01, value=0.9)
            generate_video_button = gr.Button("Generate Video")
            frame_output = gr.Image(label="Current Frame")
            video_output = gr.Video(label="Generated Video")

            generate_video_button.click(
                generate_video_gradio,
                inputs=[text_input_video, start_beta_input, target_beta_input, fps_input, length_input, top_k_input_video, top_p_input_video, seed_input_video, more_smooth_input_video],
                outputs=[frame_output, video_output],
                queue=True  # Enable queuing to allow for progress updates
            )

    demo.launch()
```
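For quick smoke-testing outside the Gradio UI, the same entry points can be driven directly from Python. A minimal sketch, assuming the app module above is importable and `VARtext_v1.pth` sits in the working directory (the prompt string is illustrative):

```python
# Programmatic use of the inference wrapper (sketch; the __main__ guard in
# app.py keeps the Gradio demo from launching on import).
import torch
from app import InrenceTextVAR

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = InrenceTextVAR(device=device)
model.load_state_dict(torch.load('VARtext_v1.pth', map_location='cpu'))
model.to(device)

# A fixed seed keeps the latent noise identical across beta values.
img = model.generate_image("a red sports car", beta=1.0, seed=42, top_k=0, top_p=0.9)
img.save('sample.png')

# Short sweep from beta=0 (class condition only) to beta=1 (full text condition).
model.generate_video("a red sports car", start_beta=0.0, target_beta=1.0,
                     fps=8, length=2.0, seed=42, output_filename='sweep.mp4')
```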
dist.py (ADDED) @@ -0,0 +1,211 @@

```python
import datetime
import functools
import os
import sys
from typing import List
from typing import Union

import torch
import torch.distributed as tdist
import torch.multiprocessing as mp

__rank, __local_rank, __world_size, __device = 0, 0, 1, 'cuda' if torch.cuda.is_available() else 'cpu'
__initialized = False


def initialized():
    return __initialized


def initialize(fork=False, backend='nccl', gpu_id_if_not_distibuted=0, timeout=30):
    global __device
    if not torch.cuda.is_available():
        print(f'[dist initialize] cuda is not available, use cpu instead', file=sys.stderr)
        return
    elif 'RANK' not in os.environ:
        torch.cuda.set_device(gpu_id_if_not_distibuted)
        __device = torch.empty(1).cuda().device
        print(f'[dist initialize] env variable "RANK" is not set, use {__device} as the device', file=sys.stderr)
        return
    # then 'RANK' must exist
    global_rank, num_gpus = int(os.environ['RANK']), torch.cuda.device_count()
    local_rank = global_rank % num_gpus
    torch.cuda.set_device(local_rank)

    # ref: https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py#L29
    if mp.get_start_method(allow_none=True) is None:
        method = 'fork' if fork else 'spawn'
        print(f'[dist initialize] mp method={method}')
        mp.set_start_method(method)
    tdist.init_process_group(backend=backend, timeout=datetime.timedelta(seconds=timeout*60))

    global __rank, __local_rank, __world_size, __initialized
    __local_rank = local_rank
    __rank, __world_size = tdist.get_rank(), tdist.get_world_size()
    __device = torch.empty(1).cuda().device
    __initialized = True

    assert tdist.is_initialized(), 'torch.distributed is not initialized!'
    print(f'[lrk={get_local_rank()}, rk={get_rank()}]')


def get_rank():
    return __rank


def get_local_rank():
    return __local_rank


def get_world_size():
    return __world_size


def get_device():
    return __device


def set_gpu_id(gpu_id: int):
    if gpu_id is None: return
    global __device
    if isinstance(gpu_id, (str, int)):
        torch.cuda.set_device(int(gpu_id))
        __device = torch.empty(1).cuda().device
    else:
        raise NotImplementedError


def is_master():
    return __rank == 0


def is_local_master():
    return __local_rank == 0


def new_group(ranks: List[int]):
    if __initialized:
        return tdist.new_group(ranks=ranks)
    return None


def barrier():
    if __initialized:
        tdist.barrier()


def allreduce(t: torch.Tensor, async_op=False):
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            ret = tdist.all_reduce(cu, async_op=async_op)
            t.copy_(cu.cpu())
        else:
            ret = tdist.all_reduce(t, async_op=async_op)
        return ret
    return None


def allgather(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()
        ls = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls, t)
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def allgather_diff_shape(t: torch.Tensor, cat=True) -> Union[List[torch.Tensor], torch.Tensor]:
    if __initialized:
        if not t.is_cuda:
            t = t.cuda()

        t_size = torch.tensor(t.size(), device=t.device)
        ls_size = [torch.empty_like(t_size) for _ in range(__world_size)]
        tdist.all_gather(ls_size, t_size)

        max_B = max(size[0].item() for size in ls_size)
        pad = max_B - t_size[0].item()
        if pad:
            pad_size = (pad, *t.size()[1:])
            t = torch.cat((t, t.new_empty(pad_size)), dim=0)

        ls_padded = [torch.empty_like(t) for _ in range(__world_size)]
        tdist.all_gather(ls_padded, t)
        ls = []
        for t, size in zip(ls_padded, ls_size):
            ls.append(t[:size[0].item()])
    else:
        ls = [t]
    if cat:
        ls = torch.cat(ls, dim=0)
    return ls


def broadcast(t: torch.Tensor, src_rank) -> None:
    if __initialized:
        if not t.is_cuda:
            cu = t.detach().cuda()
            tdist.broadcast(cu, src=src_rank)
            t.copy_(cu.cpu())
        else:
            tdist.broadcast(t, src=src_rank)


def dist_fmt_vals(val: float, fmt: Union[str, None] = '%.2f') -> Union[torch.Tensor, List]:
    if not initialized():
        return torch.tensor([val]) if fmt is None else [fmt % val]

    ts = torch.zeros(__world_size)
    ts[__rank] = val
    allreduce(ts)
    if fmt is None:
        return ts
    return [fmt % v for v in ts.cpu().numpy().tolist()]


def master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


def local_master_only(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        force = kwargs.pop('force', False)
        if force or is_local_master():
            ret = func(*args, **kwargs)
        else:
            ret = None
        barrier()
        return ret
    return wrapper


def for_visualize(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if is_master():
            # with torch.no_grad():
            ret = func(*args, **kwargs)
        else:
            ret = None
        return ret
    return wrapper


def finalize():
    if __initialized:
        tdist.destroy_process_group()
```
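The module is designed to degrade gracefully: every collective becomes a no-op (or a single-element result) when the process group was never initialized, so the same script runs on one process or many. A minimal usage sketch; the `torchrun` command is an assumption, but any launcher that sets the `RANK` environment variable works:

```python
# Launch: torchrun --nproc_per_node=4 demo.py   (a plain `python demo.py` also works)
import torch
import dist

dist.initialize()  # falls back to a no-op on CPU or when RANK is unset

x = torch.randn(8, 16, device=dist.get_device())
gathered = dist.allgather(x)  # (world_size*8, 16) when distributed, else just (8, 16)
dist.allreduce(x)             # in-place sum across ranks; no-op single-process

@dist.master_only
def log(msg):
    # Runs on rank 0 only; the other ranks return None and wait at the barrier.
    print(f'[rank {dist.get_rank()}] {msg}')

log(f'gathered shape: {tuple(gathered.shape)}')
dist.finalize()
```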
models/__init__.py (ADDED) @@ -0,0 +1,39 @@

```python
from typing import Tuple
import torch.nn as nn

from .quant import VectorQuantizer2
from .var import VAR
from .vqvae import VQVAE


def build_vae_var(
    # Shared args
    device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
    # VQVAE args
    V=4096, Cvae=32, ch=160, share_quant_resi=4,
    # VAR args
    num_classes=1000, depth=16, shared_aln=False, attn_l2_norm=True,
    flash_if_available=True, fused_if_available=True,
    init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=-1,  # init_std < 0: automated
) -> Tuple[VQVAE, VAR]:
    heads = depth
    width = depth * 64
    dpr = 0.1 * depth/24

    # disable built-in initialization for speed
    for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d, nn.ConvTranspose2d):
        setattr(clz, 'reset_parameters', lambda self: None)

    # build models
    vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi, v_patch_nums=patch_nums).to(device)
    var_wo_ddp = VAR(
        vae_local=vae_local,
        num_classes=num_classes, depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
        norm_eps=1e-6, shared_aln=shared_aln, cond_drop_rate=0.1,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
        flash_if_available=flash_if_available, fused_if_available=fused_if_available,
    ).to(device)
    var_wo_ddp.init_weights(init_adaln=init_adaln, init_adaln_gamma=init_adaln_gamma, init_head=init_head, init_std=init_std)

    return vae_local, var_wo_ddp
```
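Note that width, head count, and drop-path rate are all derived from `depth`, so a single integer selects the whole VAR size (embed_dim = depth * 64, num_heads = depth, drop_path_rate = 0.1 * depth / 24). A short sketch of a typical call, matching the depth-16 configuration used in app.py:

```python
# Sketch: build the depth-16 pair used by the inference app (1024-dim, 16 heads).
import torch
from models import build_vae_var

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vae, var = build_vae_var(device=device, depth=16)
print(f'{sum(p.numel() for p in var.parameters()) / 1e6:.1f}M params in VAR')
```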
models/basic_vae.py (ADDED) @@ -0,0 +1,226 @@

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


# this file only provides the 2 modules used in VQVAE
__all__ = ['Encoder', 'Decoder',]


"""
References: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py
"""
# swish
def nonlinearity(x):
    return x * torch.sigmoid(x)


def Normalize(in_channels, num_groups=32):
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)


class Upsample2x(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        return self.conv(F.interpolate(x, scale_factor=2, mode='nearest'))


class Downsample2x(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)

    def forward(self, x):
        return self.conv(F.pad(x, pad=(0, 1, 0, 1), mode='constant', value=0))


class ResnetBlock(nn.Module):
    def __init__(self, *, in_channels, out_channels=None, dropout):  # conv_shortcut=False: always False in VAE
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout) if dropout > 1e-6 else nn.Identity()
        self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
        else:
            self.nin_shortcut = nn.Identity()

    def forward(self, x):
        h = self.conv1(F.silu(self.norm1(x), inplace=True))
        h = self.conv2(self.dropout(F.silu(self.norm2(h), inplace=True)))
        return self.nin_shortcut(x) + h


class AttnBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.C = in_channels

        self.norm = Normalize(in_channels)
        self.qkv = torch.nn.Conv2d(in_channels, 3*in_channels, kernel_size=1, stride=1, padding=0)
        self.w_ratio = int(in_channels) ** (-0.5)
        self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        qkv = self.qkv(self.norm(x))
        B, _, H, W = qkv.shape  # should be B,3C,H,W
        C = self.C
        q, k, v = qkv.reshape(B, 3, C, H, W).unbind(1)

        # compute attention
        q = q.view(B, C, H * W).contiguous()
        q = q.permute(0, 2, 1).contiguous()     # B,HW,C
        k = k.view(B, C, H * W).contiguous()    # B,C,HW
        w = torch.bmm(q, k).mul_(self.w_ratio)  # B,HW,HW    w[B,i,j]=sum_c q[B,i,C]k[B,C,j]
        w = F.softmax(w, dim=2)

        # attend to values
        v = v.view(B, C, H * W).contiguous()
        w = w.permute(0, 2, 1).contiguous()  # B,HW,HW (first HW of k, second of q)
        h = torch.bmm(v, w)                  # B,C,HW (HW of q) h[B,C,j] = sum_i v[B,C,i] w[B,i,j]
        h = h.view(B, C, H, W).contiguous()

        return x + self.proj_out(h)


def make_attn(in_channels, using_sa=True):
    return AttnBlock(in_channels) if using_sa else nn.Identity()


class Encoder(nn.Module):
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3,
        z_channels, double_z=False, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.downsample_ratio = 2 ** (self.num_resolutions - 1)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(in_channels, self.ch, kernel_size=3, stride=1, padding=1)

        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                if i_level == self.num_resolutions - 1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample2x(block_in)
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, (2 * z_channels if double_z else z_channels), kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        # downsampling
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(h)))

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h


class Decoder(nn.Module):
    def __init__(
        self, *, ch=128, ch_mult=(1, 2, 4, 8), num_res_blocks=2,
        dropout=0.0, in_channels=3,  # in_channels: raw img channels
        z_channels, using_sa=True, using_mid_sa=True,
    ):
        super().__init__()
        self.ch = ch
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.in_channels = in_channels

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]

        # z to block_in
        self.conv_in = torch.nn.Conv2d(z_channels, block_in, kernel_size=3, stride=1, padding=1)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)
        self.mid.attn_1 = make_attn(block_in, using_sa=using_mid_sa)
        self.mid.block_2 = ResnetBlock(in_channels=block_in, out_channels=block_in, dropout=dropout)

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(ResnetBlock(in_channels=block_in, out_channels=block_out, dropout=dropout))
                block_in = block_out
                if i_level == self.num_resolutions-1 and using_sa:
                    attn.append(make_attn(block_in, using_sa=True))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample2x(block_in)
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(block_in, in_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, z):
        # z to block_in
        # middle
        h = self.mid.block_2(self.mid.attn_1(self.mid.block_1(self.conv_in(z))))

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.conv_out(F.silu(self.norm_out(h), inplace=True))
        return h
```
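The encoder downsamples by a factor of 2**(len(ch_mult) - 1) and the decoder mirrors it exactly. A quick shape check as a sketch (the small `ch=64`, three-level configuration is chosen here just to keep the example light; the commit itself uses `ch=160` with `z_channels=32`):

```python
# Sketch: verify the 4x down/up symmetry of Encoder/Decoder with 3 levels.
import torch
from models.basic_vae import Encoder, Decoder

enc = Encoder(ch=64, ch_mult=(1, 2, 4), z_channels=32)
dec = Decoder(ch=64, ch_mult=(1, 2, 4), z_channels=32)

x = torch.randn(1, 3, 64, 64)
z = enc(x)        # (1, 32, 16, 16): 64 / 2**(3-1)
x_rec = dec(z)    # (1, 3, 64, 64)
print(z.shape, x_rec.shape)
```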
models/basic_var.py (ADDED) @@ -0,0 +1,174 @@

```python
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from models.helpers import DropPath, drop_path


# this file only provides the 3 blocks used in VAR transformer
__all__ = ['FFN', 'AdaLNSelfAttn', 'AdaLNBeforeHead']


# automatically import fused operators
dropout_add_layer_norm = fused_mlp_func = memory_efficient_attention = flash_attn_func = None
try:
    from flash_attn.ops.layer_norm import dropout_add_layer_norm
    from flash_attn.ops.fused_dense import fused_mlp_func
except ImportError: pass
# automatically import faster attention implementations
try: from xformers.ops import memory_efficient_attention
except ImportError: pass
try: from flash_attn import flash_attn_func               # qkv: BLHc, ret: BLHcq
except ImportError: pass
try: from torch.nn.functional import scaled_dot_product_attention as slow_attn    # q, k, v: BHLc
except ImportError:
    def slow_attn(query, key, value, scale: float, attn_mask=None, dropout_p=0.0):
        attn = query.mul(scale) @ key.transpose(-2, -1)  # BHLc @ BHcL => BHLL
        if attn_mask is not None: attn.add_(attn_mask)
        return (F.dropout(attn.softmax(dim=-1), p=dropout_p, inplace=True) if dropout_p > 0 else attn.softmax(dim=-1)) @ value


class FFN(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, drop=0., fused_if_available=True):
        super().__init__()
        self.fused_mlp_func = fused_mlp_func if fused_if_available else None
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU(approximate='tanh')
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop, inplace=True) if drop > 0 else nn.Identity()

    def forward(self, x):
        if self.fused_mlp_func is not None:
            return self.drop(self.fused_mlp_func(
                x=x, weight1=self.fc1.weight, weight2=self.fc2.weight, bias1=self.fc1.bias, bias2=self.fc2.bias,
                activation='gelu_approx', save_pre_act=self.training, return_residual=False, checkpoint_lvl=0,
                heuristic=0, process_group=None,
            ))
        else:
            return self.drop(self.fc2( self.act(self.fc1(x)) ))

    def extra_repr(self) -> str:
        return f'fused_mlp_func={self.fused_mlp_func is not None}'


class SelfAttention(nn.Module):
    def __init__(
        self, block_idx, embed_dim=768, num_heads=12,
        attn_drop=0., proj_drop=0., attn_l2_norm=False, flash_if_available=True,
    ):
        super().__init__()
        assert embed_dim % num_heads == 0
        self.block_idx, self.num_heads, self.head_dim = block_idx, num_heads, embed_dim // num_heads  # =64
        self.attn_l2_norm = attn_l2_norm
        if self.attn_l2_norm:
            self.scale = 1
            self.scale_mul_1H11 = nn.Parameter(torch.full(size=(1, self.num_heads, 1, 1), fill_value=4.0).log(), requires_grad=True)
            self.max_scale_mul = torch.log(torch.tensor(100)).item()
        else:
            self.scale = 0.25 / math.sqrt(self.head_dim)

        self.mat_qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)
        self.q_bias, self.v_bias = nn.Parameter(torch.zeros(embed_dim)), nn.Parameter(torch.zeros(embed_dim))
        self.register_buffer('zero_k_bias', torch.zeros(embed_dim))

        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_drop, inplace=True) if proj_drop > 0 else nn.Identity()
        self.attn_drop: float = attn_drop
        self.using_flash = flash_if_available and flash_attn_func is not None
        self.using_xform = flash_if_available and memory_efficient_attention is not None

        # only used during inference
        self.caching, self.cached_k, self.cached_v = False, None, None

    def kv_caching(self, enable: bool): self.caching, self.cached_k, self.cached_v = enable, None, None

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, attn_bias):
        B, L, C = x.shape

        qkv = F.linear(input=x, weight=self.mat_qkv.weight, bias=torch.cat((self.q_bias, self.zero_k_bias, self.v_bias))).view(B, L, 3, self.num_heads, self.head_dim)
        main_type = qkv.dtype
        # qkv: BL3Hc

        using_flash = self.using_flash and attn_bias is None and qkv.dtype != torch.float32
        if using_flash or self.using_xform: q, k, v = qkv.unbind(dim=2); dim_cat = 1   # q or k or v: BLHc
        else: q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(dim=0); dim_cat = 2          # q or k or v: BHLc

        if self.attn_l2_norm:
            scale_mul = self.scale_mul_1H11.clamp_max(self.max_scale_mul).exp()
            if using_flash or self.using_xform: scale_mul = scale_mul.transpose(1, 2)  # 1H11 to 11H1
            q = F.normalize(q, dim=-1).mul(scale_mul)
            k = F.normalize(k, dim=-1)

        if self.caching:
            if self.cached_k is None: self.cached_k = k; self.cached_v = v
            else: k = self.cached_k = torch.cat((self.cached_k, k), dim=dim_cat); v = self.cached_v = torch.cat((self.cached_v, v), dim=dim_cat)

        dropout_p = self.attn_drop if self.training else 0.0
        if using_flash:
            oup = flash_attn_func(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), dropout_p=dropout_p, softmax_scale=self.scale).view(B, L, C)
        elif self.using_xform:
            oup = memory_efficient_attention(q.to(dtype=main_type), k.to(dtype=main_type), v.to(dtype=main_type), attn_bias=None if attn_bias is None else attn_bias.to(dtype=main_type).expand(B, self.num_heads, -1, -1), p=dropout_p, scale=self.scale).view(B, L, C)
        else:
            oup = slow_attn(query=q, key=k, value=v, scale=self.scale, attn_mask=attn_bias, dropout_p=dropout_p).transpose(1, 2).reshape(B, L, C)

        return self.proj_drop(self.proj(oup))
        # attn = (q @ k.transpose(-2, -1)).add_(attn_bias + self.local_rpb())  # BHLc @ BHcL => BHLL
        # attn = self.attn_drop(attn.softmax(dim=-1))
        # oup = (attn @ v).transpose_(1, 2).reshape(B, L, -1)  # BHLL @ BHLc = BHLc => BLHc => BLC

    def extra_repr(self) -> str:
        return f'using_flash={self.using_flash}, using_xform={self.using_xform}, attn_l2_norm={self.attn_l2_norm}'


class AdaLNSelfAttn(nn.Module):
    def __init__(
        self, block_idx, last_drop_p, embed_dim, cond_dim, shared_aln: bool, norm_layer,
        num_heads, mlp_ratio=4., drop=0., attn_drop=0., drop_path=0., attn_l2_norm=False,
        flash_if_available=False, fused_if_available=True,
    ):
        super(AdaLNSelfAttn, self).__init__()
        self.block_idx, self.last_drop_p, self.C = block_idx, last_drop_p, embed_dim
        self.C, self.D = embed_dim, cond_dim
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.attn = SelfAttention(block_idx=block_idx, embed_dim=embed_dim, num_heads=num_heads, attn_drop=attn_drop, proj_drop=drop, attn_l2_norm=attn_l2_norm, flash_if_available=flash_if_available)
        self.ffn = FFN(in_features=embed_dim, hidden_features=round(embed_dim * mlp_ratio), drop=drop, fused_if_available=fused_if_available)

        self.ln_wo_grad = norm_layer(embed_dim, elementwise_affine=False)
        self.shared_aln = shared_aln
        if self.shared_aln:
            self.ada_gss = nn.Parameter(torch.randn(1, 1, 6, embed_dim) / embed_dim**0.5)
        else:
            lin = nn.Linear(cond_dim, 6*embed_dim)
            self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), lin)

        self.fused_add_norm_fn = None

    # NOTE: attn_bias is None during inference because kv cache is enabled
    def forward(self, x, cond_BD, attn_bias):  # C: embed_dim, D: cond_dim
        if self.shared_aln:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = (self.ada_gss + cond_BD).unbind(2)  # 116C + B16C =unbind(2)=> 6 B1C
        else:
            gamma1, gamma2, scale1, scale2, shift1, shift2 = self.ada_lin(cond_BD).view(-1, 1, 6, self.C).unbind(2)
        x = x + self.drop_path(self.attn( self.ln_wo_grad(x).mul(scale1.add(1)).add_(shift1), attn_bias=attn_bias ).mul_(gamma1))
        x = x + self.drop_path(self.ffn( self.ln_wo_grad(x).mul(scale2.add(1)).add_(shift2) ).mul(gamma2))  # this mul(gamma2) cannot be in-placed when FusedMLP is used
        return x

    def extra_repr(self) -> str:
        return f'shared_aln={self.shared_aln}'


class AdaLNBeforeHead(nn.Module):
    def __init__(self, C, D, norm_layer):   # C: embed_dim, D: cond_dim
        super().__init__()
        self.C, self.D = C, D
        self.ln_wo_grad = norm_layer(C, elementwise_affine=False)
        self.ada_lin = nn.Sequential(nn.SiLU(inplace=False), nn.Linear(D, 2*C))

    def forward(self, x_BLC: torch.Tensor, cond_BD: torch.Tensor):
        scale, shift = self.ada_lin(cond_BD).view(-1, 1, 2, self.C).unbind(2)
        return self.ln_wo_grad(x_BLC).mul(scale.add(1)).add_(shift)
```
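The heart of `AdaLNSelfAttn` is its six-way conditioning projection: one SiLU + Linear maps the condition vector to per-block (gamma, scale, shift) pairs, one pair for the attention branch and one for the FFN branch. A stripped-down sketch of that modulation pattern; the names here are illustrative, not the module's API:

```python
# Sketch of the adaptive-LayerNorm modulation used by AdaLNSelfAttn.
import torch
import torch.nn as nn

C, D = 1024, 1024                        # embed_dim, cond_dim
ada_lin = nn.Sequential(nn.SiLU(), nn.Linear(D, 6 * C))

x = torch.randn(2, 680, C)               # token sequence
cond = torch.randn(2, D)                  # class + text condition
g1, g2, s1, s2, b1, b2 = ada_lin(cond).view(-1, 1, 6, C).unbind(2)  # six (B,1,C) tensors

ln = nn.LayerNorm(C, elementwise_affine=False)  # affine comes from the condition instead
h = ln(x) * (s1 + 1) + b1                 # modulated input to the attention branch
# The block then computes: x = x + drop_path(attn(h)) * g1, and repeats the
# same (scale, shift, gate) pattern with s2, b2, g2 for the FFN branch.
```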
models/helpers.py (ADDED) @@ -0,0 +1,59 @@

```python
import torch
from torch import nn as nn
from torch.nn import functional as F


def sample_with_top_k_top_p_(logits_BlV: torch.Tensor, top_k: int = 0, top_p: float = 0.0, rng=None, num_samples=1) -> torch.Tensor:  # return idx, shaped (B, l)
    B, l, V = logits_BlV.shape
    if top_k > 0:
        idx_to_remove = logits_BlV < logits_BlV.topk(top_k, largest=True, sorted=False, dim=-1)[0].amin(dim=-1, keepdim=True)
        logits_BlV.masked_fill_(idx_to_remove, -torch.inf)
    if top_p > 0:
        sorted_logits, sorted_idx = logits_BlV.sort(dim=-1, descending=False)
        sorted_idx_to_remove = sorted_logits.softmax(dim=-1).cumsum_(dim=-1) <= (1 - top_p)
        sorted_idx_to_remove[..., -1:] = False
        logits_BlV.masked_fill_(sorted_idx_to_remove.scatter(sorted_idx.ndim - 1, sorted_idx, sorted_idx_to_remove), -torch.inf)
    # sample (have to squeeze cuz torch.multinomial can only be used for 2D tensor)
    replacement = num_samples >= 0
    num_samples = abs(num_samples)
    return torch.multinomial(logits_BlV.softmax(dim=-1).view(-1, V), num_samples=num_samples, replacement=replacement, generator=rng).view(B, l, num_samples)


def gumbel_softmax_with_rng(logits: torch.Tensor, tau: float = 1, hard: bool = False, eps: float = 1e-10, dim: int = -1, rng: torch.Generator = None) -> torch.Tensor:
    if rng is None:
        return F.gumbel_softmax(logits=logits, tau=tau, hard=hard, eps=eps, dim=dim)

    gumbels = (-torch.empty_like(logits, memory_format=torch.legacy_contiguous_format).exponential_(generator=rng).log())
    gumbels = (logits + gumbels) / tau
    y_soft = gumbels.softmax(dim)

    if hard:
        index = y_soft.max(dim, keepdim=True)[1]
        y_hard = torch.zeros_like(logits, memory_format=torch.legacy_contiguous_format).scatter_(dim, index, 1.0)
        ret = y_hard - y_soft.detach() + y_soft
    else:
        ret = y_soft
    return ret


def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):  # taken from timm
    if drop_prob == 0. or not training: return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0 and scale_by_keep:
        random_tensor.div_(keep_prob)
    return x * random_tensor


class DropPath(nn.Module):  # taken from timm
    def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob
        self.scale_by_keep = scale_by_keep

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)

    def extra_repr(self):
        return f'(drop_prob=...)'
```
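Note the trailing underscore in `sample_with_top_k_top_p_`: it masks `logits_BlV` in place before sampling, so pass a clone if you still need the raw logits. A small usage sketch:

```python
# Sketch: top-k/top-p sampling over a batch of per-position logits.
import torch
from models.helpers import sample_with_top_k_top_p_

logits = torch.randn(2, 5, 4096)                 # (B, l, V)
idx = sample_with_top_k_top_p_(logits.clone(),   # clone: the helper masks in place
                               top_k=50, top_p=0.9, num_samples=1)
print(idx.shape)                                 # torch.Size([2, 5, 1])
```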
models/quant.py
ADDED
@@ -0,0 +1,281 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Optional, Sequence, Tuple, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from torch import distributed as tdist, nn as nn
|
6 |
+
from torch.nn import functional as F
|
7 |
+
|
8 |
+
import dist
|
9 |
+
|
10 |
+
# this file only provides the VectorQuantizer2 used in VQVAE
|
11 |
+
__all__ = ['VectorQuantizer2', ]
|
12 |
+
|
13 |
+
|
14 |
+
class VectorQuantizer2(nn.Module):
|
15 |
+
# VQGAN originally use beta=1.0, never tried 0.25; SD seems using 0.25
|
16 |
+
def __init__(
|
17 |
+
self, vocab_size, Cvae, using_znorm, beta: float = 0.25,
|
18 |
+
default_qresi_counts=0, v_patch_nums=None, quant_resi=0.5, share_quant_resi=4, # share_quant_resi: args.qsr
|
19 |
+
):
|
20 |
+
super().__init__()
|
21 |
+
self.vocab_size: int = vocab_size
|
22 |
+
self.Cvae: int = Cvae
|
23 |
+
self.using_znorm: bool = using_znorm
|
24 |
+
self.v_patch_nums: Tuple[int] = v_patch_nums
|
25 |
+
|
26 |
+
self.quant_resi_ratio = quant_resi
|
27 |
+
if share_quant_resi == 0: # non-shared: \phi_{1 to K} for K scales
|
28 |
+
self.quant_resi = PhiNonShared(
|
29 |
+
[(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
|
30 |
+
range(default_qresi_counts or len(self.v_patch_nums))])
|
31 |
+
elif share_quant_resi == 1: # fully shared: only a single \phi for K scales
|
32 |
+
self.quant_resi = PhiShared(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity())
|
33 |
+
else: # partially shared: \phi_{1 to share_quant_resi} for K scales
|
34 |
+
self.quant_resi = PhiPartiallyShared(nn.ModuleList(
|
35 |
+
[(Phi(Cvae, quant_resi) if abs(quant_resi) > 1e-6 else nn.Identity()) for _ in
|
36 |
+
range(share_quant_resi)]))
|
37 |
+
|
38 |
+
self.register_buffer('ema_vocab_hit_SV', torch.full((len(self.v_patch_nums), self.vocab_size), fill_value=0.0))
|
39 |
+
self.record_hit = 0
|
40 |
+
|
41 |
+
self.beta: float = beta
|
42 |
+
self.embedding = nn.Embedding(self.vocab_size, self.Cvae)
|
43 |
+
|
44 |
+
# only used for progressive training of VAR (not supported yet, will be tested and supported in the future)
|
45 |
+
self.prog_si = -1 # progressive training: not supported yet, prog_si always -1
|
46 |
+
|
47 |
+
def eini(self, eini):
|
48 |
+
if eini > 0:
|
49 |
+
nn.init.trunc_normal_(self.embedding.weight.data, std=eini)
|
50 |
+
elif eini < 0:
|
51 |
+
self.embedding.weight.data.uniform_(-abs(eini) / self.vocab_size, abs(eini) / self.vocab_size)
|
52 |
+
|
53 |
+
def extra_repr(self) -> str:
|
54 |
+
return f'{self.v_patch_nums}, znorm={self.using_znorm}, beta={self.beta} | S={len(self.v_patch_nums)}, quant_resi={self.quant_resi_ratio}'
|
55 |
+
|
56 |
+
# ===================== `forward` is only used in VAE training =====================
|
57 |
+
def forward(self, f_BChw: torch.Tensor, ret_usages=False) -> Tuple[torch.Tensor, List[float], torch.Tensor]:
|
58 |
+
dtype = f_BChw.dtype
|
59 |
+
if dtype != torch.float32: f_BChw = f_BChw.float()
|
60 |
+
B, C, H, W = f_BChw.shape
|
61 |
+
f_no_grad = f_BChw.detach()
|
62 |
+
|
63 |
+
f_rest = f_no_grad.clone()
|
64 |
+
f_hat = torch.zeros_like(f_rest)
|
65 |
+
|
66 |
+
with torch.cuda.amp.autocast(enabled=False):
|
67 |
+
mean_vq_loss: torch.Tensor = 0.0
|
68 |
+
vocab_hit_V = torch.zeros(self.vocab_size, dtype=torch.float, device=f_BChw.device)
|
69 |
+
SN = len(self.v_patch_nums)
|
70 |
+
for si, pn in enumerate(self.v_patch_nums): # from small to large
|
71 |
+
# find the nearest embedding
|
72 |
+
if self.using_znorm:
|
73 |
+
rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1,
|
74 |
+
C) if (
|
75 |
+
si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
|
76 |
+
rest_NC = F.normalize(rest_NC, dim=-1)
|
77 |
+
idx_N = torch.argmax(rest_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
|
78 |
+
else:
|
79 |
+
rest_NC = F.interpolate(f_rest, size=(pn, pn), mode='bilinear').permute(0, 2, 3, 1).reshape(-1,
|
80 |
+
C) if (
|
81 |
+
si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
|
82 |
+
d_no_grad = torch.sum(rest_NC.square(), dim=1, keepdim=True) + torch.sum(
|
83 |
+
self.embedding.weight.data.square(), dim=1, keepdim=False)
|
84 |
+
d_no_grad.addmm_(rest_NC, self.embedding.weight.data.T, alpha=-2, beta=1) # (B*h*w, vocab_size)
|
85 |
+
idx_N = torch.argmin(d_no_grad, dim=1)
|
86 |
+
|
87 |
+
hit_V = idx_N.bincount(minlength=self.vocab_size).float()
|
88 |
+
if self.training:
|
89 |
+
if dist.initialized(): handler = tdist.all_reduce(hit_V, async_op=True)
|
90 |
+
|
91 |
+
# calc loss
|
92 |
+
idx_Bhw = idx_N.view(B, pn, pn)
|
93 |
+
h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W),
|
94 |
+
mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(
|
95 |
+
idx_Bhw).permute(0, 3, 1, 2).contiguous()
|
96 |
+
h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
|
97 |
+
f_hat = f_hat + h_BChw
|
98 |
+
f_rest -= h_BChw
|
99 |
+
|
100 |
+
if self.training and dist.initialized():
|
101 |
+
handler.wait()
|
102 |
+
if self.record_hit == 0:
|
103 |
+
self.ema_vocab_hit_SV[si].copy_(hit_V)
|
104 |
+
elif self.record_hit < 100:
|
105 |
+
self.ema_vocab_hit_SV[si].mul_(0.9).add_(hit_V.mul(0.1))
|
106 |
+
else:
|
107 |
+
self.ema_vocab_hit_SV[si].mul_(0.99).add_(hit_V.mul(0.01))
|
108 |
+
self.record_hit += 1
|
109 |
+
vocab_hit_V.add_(hit_V)
|
110 |
+
mean_vq_loss += F.mse_loss(f_hat.data, f_BChw).mul_(self.beta) + F.mse_loss(f_hat, f_no_grad)
|
111 |
+
|
112 |
+
mean_vq_loss *= 1. / SN
|
113 |
+
f_hat = (f_hat.data - f_no_grad).add_(f_BChw)
|
114 |
+
|
115 |
+
margin = tdist.get_world_size() * (f_BChw.numel() / f_BChw.shape[1]) / self.vocab_size * 0.08
|
116 |
+
# margin = pn*pn / 100
|
117 |
+
if ret_usages:
|
118 |
+
usages = [(self.ema_vocab_hit_SV[si] >= margin).float().mean().item() * 100 for si, pn in
|
119 |
+
enumerate(self.v_patch_nums)]
|
120 |
+
else:
|
121 |
+
usages = None
|
122 |
+
return f_hat, usages, mean_vq_loss

    # ===================== `forward` is only used in VAE training =====================

    def embed_to_fhat(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale=True, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        ls_f_hat_BChw = []
        B = ms_h_BChw[0].shape[0]
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)
        if all_to_max_scale:
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, H, W, dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                h_BChw = ms_h_BChw[si]
                if si < len(self.v_patch_nums) - 1:
                    h_BChw = F.interpolate(h_BChw, size=(H, W), mode='bilinear')
                h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
                f_hat.add_(h_BChw)
                if last_one:
                    ls_f_hat_BChw = f_hat
                else:
                    ls_f_hat_BChw.append(f_hat.clone())
        else:
            # WARNING: this is not the case in VQ-VAE training or inference (we'll interpolate every token map to the max H W, like above)
            # WARNING: this should only be used for experimental purposes
            f_hat = ms_h_BChw[0].new_zeros(B, self.Cvae, self.v_patch_nums[0], self.v_patch_nums[0], dtype=torch.float32)
            for si, pn in enumerate(self.v_patch_nums):  # from small to large
                f_hat = F.interpolate(f_hat, size=(pn, pn), mode='bilinear')
                h_BChw = self.quant_resi[si / (SN - 1)](ms_h_BChw[si])
                f_hat.add_(h_BChw)
                if last_one:
                    ls_f_hat_BChw = f_hat
                else:
                    ls_f_hat_BChw.append(f_hat)

        return ls_f_hat_BChw

    def f_to_idxBl_or_fhat(self, f_BChw: torch.Tensor, to_fhat: bool, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[Union[torch.Tensor, torch.LongTensor]]:  # z_BChw is the feature from inp_img_no_grad
        B, C, H, W = f_BChw.shape
        f_no_grad = f_BChw.detach()
        f_rest = f_no_grad.clone()
        f_hat = torch.zeros_like(f_rest)

        f_hat_or_idx_Bl: List[torch.Tensor] = []

        patch_hws = [(pn, pn) if isinstance(pn, int) else (pn[0], pn[1]) for pn in (v_patch_nums or self.v_patch_nums)]  # from small to large
        assert patch_hws[-1][0] == H and patch_hws[-1][1] == W, f'{patch_hws[-1]=} != ({H=}, {W=})'

        SN = len(patch_hws)
        for si, (ph, pw) in enumerate(patch_hws):  # from small to large
            if 0 <= self.prog_si < si: break  # progressive training: not supported yet, prog_si always -1
            # find the nearest embedding
            z_NC = F.interpolate(f_rest, size=(ph, pw), mode='bilinear').permute(0, 2, 3, 1).reshape(-1, C) if (si != SN - 1) else f_rest.permute(0, 2, 3, 1).reshape(-1, C)
            if self.using_znorm:
                z_NC = F.normalize(z_NC, dim=-1)
                idx_N = torch.argmax(z_NC @ F.normalize(self.embedding.weight.data.T, dim=0), dim=1)
            else:
                d_no_grad = torch.sum(z_NC.square(), dim=1, keepdim=True) + torch.sum(self.embedding.weight.data.square(), dim=1, keepdim=False)
                d_no_grad.addmm_(z_NC, self.embedding.weight.data.T, alpha=-2, beta=1)  # (B*h*w, vocab_size)
                idx_N = torch.argmin(d_no_grad, dim=1)

            idx_Bhw = idx_N.view(B, ph, pw)
            h_BChw = F.interpolate(self.embedding(idx_Bhw).permute(0, 3, 1, 2), size=(H, W), mode='bilinear').contiguous() if (si != SN - 1) else self.embedding(idx_Bhw).permute(0, 3, 1, 2).contiguous()
            h_BChw = self.quant_resi[si / (SN - 1)](h_BChw)
            f_hat.add_(h_BChw)
            f_rest.sub_(h_BChw)
            f_hat_or_idx_Bl.append(f_hat.clone() if to_fhat else idx_N.reshape(B, ph * pw))

        return f_hat_or_idx_Bl

    # ===================== idxBl_to_var_input: only used in VAR training, for getting teacher-forcing input =====================
    def idxBl_to_var_input(self, gt_ms_idx_Bl: List[torch.Tensor]) -> torch.Tensor:
        next_scales = []
        B = gt_ms_idx_Bl[0].shape[0]
        C = self.Cvae
        H = W = self.v_patch_nums[-1]
        SN = len(self.v_patch_nums)

        f_hat = gt_ms_idx_Bl[0].new_zeros(B, C, H, W, dtype=torch.float32)
        pn_next: int = self.v_patch_nums[0]
        for si in range(SN - 1):
            if self.prog_si == 0 or (0 <= self.prog_si - 1 < si): break  # progressive training: not supported yet, prog_si always -1
            h_BChw = F.interpolate(self.embedding(gt_ms_idx_Bl[si]).transpose_(1, 2).view(B, C, pn_next, pn_next), size=(H, W), mode='bilinear')
            f_hat.add_(self.quant_resi[si / (SN - 1)](h_BChw))
            pn_next = self.v_patch_nums[si + 1]
            next_scales.append(F.interpolate(f_hat, size=(pn_next, pn_next), mode='bilinear').view(B, C, -1).transpose(1, 2))
        return torch.cat(next_scales, dim=1) if len(next_scales) else None  # cat BlCs to BLC, this should be float32

    # ===================== get_next_autoregressive_input: only used in VAR inference, for getting next step's input =====================
    def get_next_autoregressive_input(self, si: int, SN: int, f_hat: torch.Tensor, h_BChw: torch.Tensor) -> Tuple[Optional[torch.Tensor], torch.Tensor]:  # only used in VAR inference
        HW = self.v_patch_nums[-1]
        if si != SN - 1:
            h = self.quant_resi[si / (SN - 1)](F.interpolate(h_BChw, size=(HW, HW), mode='bilinear'))  # conv after upsample
            f_hat.add_(h)
            return f_hat, F.interpolate(f_hat, size=(self.v_patch_nums[si + 1], self.v_patch_nums[si + 1]), mode='bilinear')
        else:
            h = self.quant_resi[si / (SN - 1)](h_BChw)
            f_hat.add_(h)
            return f_hat, f_hat


class Phi(nn.Conv2d):
    def __init__(self, embed_dim, quant_resi):
        ks = 3
        super().__init__(in_channels=embed_dim, out_channels=embed_dim, kernel_size=ks, stride=1, padding=ks // 2)
        self.resi_ratio = abs(quant_resi)

    def forward(self, h_BChw):
        return h_BChw.mul(1 - self.resi_ratio) + super().forward(h_BChw).mul_(self.resi_ratio)


class PhiShared(nn.Module):
    def __init__(self, qresi: Phi):
        super().__init__()
        self.qresi: Phi = qresi

    def __getitem__(self, _) -> Phi:
        return self.qresi


class PhiPartiallyShared(nn.Module):
    def __init__(self, qresi_ls: nn.ModuleList):
        super().__init__()
        self.qresi_ls = qresi_ls
        K = len(qresi_ls)
        self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        return self.qresi_ls[np.argmin(np.abs(self.ticks - at_from_0_to_1)).item()]

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'


class PhiNonShared(nn.ModuleList):
    def __init__(self, qresi: List):
        super().__init__(qresi)
        # self.qresi = qresi
        K = len(qresi)
        self.ticks = np.linspace(1 / 3 / K, 1 - 1 / 3 / K, K) if K == 4 else np.linspace(1 / 2 / K, 1 - 1 / 2 / K, K)

    def __getitem__(self, at_from_0_to_1: float) -> Phi:
        return super().__getitem__(np.argmin(np.abs(self.ticks - at_from_0_to_1)).item())

    def extra_repr(self) -> str:
        return f'ticks={self.ticks}'
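
To make the fractional indexing used throughout these methods concrete: `quant_resi[si / (SN - 1)]` maps a scale's position in [0, 1] to the nearest of K partially-shared Phi convolutions via the `ticks` grid. A minimal sketch (illustrative sizes, not part of the commit):

# Sketch: resolving the fractional quant_resi index to a concrete Phi conv.
# K, embed_dim, and SN below are illustrative values, not taken from the commit.
import numpy as np
import torch
import torch.nn as nn

K, embed_dim, SN = 4, 32, 10
phis = PhiPartiallyShared(nn.ModuleList([Phi(embed_dim, 0.5) for _ in range(K)]))

for si in (0, 3, 9):
    ratio = si / (SN - 1)                            # scale position in [0, 1]
    k = int(np.argmin(np.abs(phis.ticks - ratio)))   # nearest of the K shared convs
    print(f'si={si}: ratio={ratio:.2f} -> phi #{k}')

h = torch.randn(1, embed_dim, 16, 16)
out = phis[9 / (SN - 1)](h)   # 0.5 * h + 0.5 * conv(h), per resi_ratio=0.5
print(out.shape)              # torch.Size([1, 32, 16, 16])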
models/var.py
ADDED
@@ -0,0 +1,360 @@
import math
from functools import partial
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from huggingface_hub import PyTorchModelHubMixin

import dist
from models.basic_var import AdaLNBeforeHead, AdaLNSelfAttn
from models.helpers import gumbel_softmax_with_rng, sample_with_top_k_top_p_
from models.vqvae import VQVAE, VectorQuantizer2


class SharedAdaLin(nn.Linear):
    def forward(self, cond_BD):
        C = self.weight.shape[0] // 6
        return super().forward(cond_BD).view(-1, 1, 6, C)  # B16C


class VAR(nn.Module):
    def __init__(
            self, vae_local: VQVAE,
            num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
            norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
            attn_l2_norm=False,
            patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
            flash_if_available=True, fused_if_available=True,
    ):
        super().__init__()
        # 0. hyperparameters
        assert embed_dim % num_heads == 0
        self.Cvae, self.V = vae_local.Cvae, vae_local.vocab_size
        self.depth, self.C, self.D, self.num_heads = depth, embed_dim, embed_dim, num_heads

        self.cond_drop_rate = cond_drop_rate
        self.prog_si = -1  # progressive training

        self.patch_nums: Tuple[int] = patch_nums
        self.L = sum(pn ** 2 for pn in self.patch_nums)
        self.first_l = self.patch_nums[0] ** 2
        self.begin_ends = []
        cur = 0
        for i, pn in enumerate(self.patch_nums):
            self.begin_ends.append((cur, cur + pn ** 2))
            cur += pn ** 2

        self.num_stages_minus_1 = len(self.patch_nums) - 1
        self.rng = torch.Generator(device=dist.get_device())

        # 1. input (word) embedding
        quant: VectorQuantizer2 = vae_local.quantize
        self.vae_proxy: Tuple[VQVAE] = (vae_local,)
        self.vae_quant_proxy: Tuple[VectorQuantizer2] = (quant,)
        self.word_embed = nn.Linear(self.Cvae, self.C)

        # 2. class embedding
        init_std = math.sqrt(1 / self.C / 3)
        self.num_classes = num_classes
        self.uniform_prob = torch.full((1, num_classes), fill_value=1.0 / num_classes, dtype=torch.float32, device=dist.get_device())
        self.class_emb = nn.Embedding(self.num_classes + 1, self.C)
        nn.init.trunc_normal_(self.class_emb.weight.data, mean=0, std=init_std)
        self.pos_start = nn.Parameter(torch.empty(1, self.first_l, self.C))
        nn.init.trunc_normal_(self.pos_start.data, mean=0, std=init_std)

        # 3. absolute position embedding
        pos_1LC = []
        for i, pn in enumerate(self.patch_nums):
            pe = torch.empty(1, pn * pn, self.C)
            nn.init.trunc_normal_(pe, mean=0, std=init_std)
            pos_1LC.append(pe)
        pos_1LC = torch.cat(pos_1LC, dim=1)  # 1, L, C
        assert tuple(pos_1LC.shape) == (1, self.L, self.C)
        self.pos_1LC = nn.Parameter(pos_1LC)
        # level embedding (similar to GPT's segment embedding, used to distinguish different levels of token pyramid)
        self.lvl_embed = nn.Embedding(len(self.patch_nums), self.C)
        nn.init.trunc_normal_(self.lvl_embed.weight.data, mean=0, std=init_std)

        # 4. backbone blocks
        self.shared_ada_lin = nn.Sequential(nn.SiLU(inplace=False), SharedAdaLin(self.D, 6 * self.C)) if shared_aln else nn.Identity()

        norm_layer = partial(nn.LayerNorm, eps=norm_eps)
        self.drop_path_rate = drop_path_rate
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule (linearly increasing)
        self.blocks = nn.ModuleList([
            AdaLNSelfAttn(
                cond_dim=self.D, shared_aln=shared_aln,
                block_idx=block_idx, embed_dim=self.C, norm_layer=norm_layer, num_heads=num_heads, mlp_ratio=mlp_ratio,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[block_idx],
                last_drop_p=0 if block_idx == 0 else dpr[block_idx - 1],
                attn_l2_norm=attn_l2_norm,
                flash_if_available=flash_if_available, fused_if_available=fused_if_available,
            )
            for block_idx in range(depth)
        ])

        fused_add_norm_fns = [b.fused_add_norm_fn is not None for b in self.blocks]
        self.using_fused_add_norm_fn = any(fused_add_norm_fns)
        print(
            f'\n[constructor] ==== flash_if_available={flash_if_available} ({sum(b.attn.using_flash for b in self.blocks)}/{self.depth}), fused_if_available={fused_if_available} (fusing_add_ln={sum(fused_add_norm_fns)}/{self.depth}, fusing_mlp={sum(b.ffn.fused_mlp_func is not None for b in self.blocks)}/{self.depth}) ==== \n'
            f'    [VAR config ] embed_dim={embed_dim}, num_heads={num_heads}, depth={depth}, mlp_ratio={mlp_ratio}\n'
            f'    [drop ratios ] drop_rate={drop_rate}, attn_drop_rate={attn_drop_rate}, drop_path_rate={drop_path_rate:g} ({torch.linspace(0, drop_path_rate, depth)})',
            end='\n\n', flush=True
        )

        # 5. attention mask used in training (for masking out the future)
        #    it won't be used in inference, since kv cache is enabled
        d: torch.Tensor = torch.cat([torch.full((pn * pn,), i) for i, pn in enumerate(self.patch_nums)]).view(1, self.L, 1)
        dT = d.transpose(1, 2)  # dT: 11L
        lvl_1L = dT[:, 0].contiguous()
        self.register_buffer('lvl_1L', lvl_1L)
        attn_bias_for_masking = torch.where(d >= dT, 0., -torch.inf).reshape(1, 1, self.L, self.L)
        self.register_buffer('attn_bias_for_masking', attn_bias_for_masking.contiguous())

        # 6. classifier head
        self.head_nm = AdaLNBeforeHead(self.C, self.D, norm_layer=norm_layer)
        self.head = nn.Linear(self.C, self.V)

    def get_logits(self, h_or_h_and_residual: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]], cond_BD: Optional[torch.Tensor]):
        if not isinstance(h_or_h_and_residual, torch.Tensor):
            h, resi = h_or_h_and_residual  # fused_add_norm must be used
            h = resi + self.blocks[-1].drop_path(h)
        else:  # fused_add_norm is not used
            h = h_or_h_and_residual
        return self.head(self.head_nm(h.float(), cond_BD).float()).float()

    @torch.no_grad()
    def autoregressive_infer_cfg(
            self, B: int, label_B: Optional[Union[int, torch.LongTensor]],
            delta_condition: torch.Tensor, alpha: float, beta: float,
            g_seed: Optional[int] = None, cfg=1.5, top_k=0, top_p=0.0,
            more_smooth=False,
    ) -> torch.Tensor:  # returns reconstructed image (B, 3, H, W) in [0, 1]
        """
        Generate images using autoregressive inference with classifier-free guidance.
        :param B: batch size
        :param label_B: class labels; if None, randomly sampled
        :param delta_condition: tensor of shape (B, D)
        :param alpha: scalar weight for class embedding
        :param beta: scalar weight for delta_condition
        :param g_seed: random seed
        :param cfg: classifier-free guidance ratio
        :param top_k: top-k sampling
        :param top_p: top-p sampling
        :param more_smooth: smoothing the pred using gumbel softmax; only used in visualization, not used in FID/IS benchmarking
        :return: reconstructed images (B, 3, H, W)
        """
        if g_seed is None:
            rng = None
        else:
            self.rng.manual_seed(g_seed)
            rng = self.rng

        device = self.lvl_1L.device
        if label_B is None:
            label_B = torch.multinomial(self.uniform_prob, num_samples=B, replacement=True, generator=rng).reshape(B)
        elif isinstance(label_B, int):
            label_B = torch.full((B,), fill_value=self.num_classes if label_B < 0 else label_B, device=device)

        # Prepare labels for conditioned and unconditioned versions
        label_B_cond = label_B
        label_B_uncond = torch.full_like(label_B, fill_value=self.num_classes)
        label_B = torch.cat((label_B_cond, label_B_uncond), dim=0)  # shape (2B,)

        # Prepare delta_condition for conditioned and unconditioned versions
        delta_condition_uncond = torch.zeros_like(delta_condition)
        delta_condition = torch.cat((delta_condition, delta_condition_uncond), dim=0)  # shape (2B, D)

        class_emb = self.class_emb(label_B)  # shape (2B, D)
        cond_BD = alpha * class_emb + beta * delta_condition  # shape (2B, D)

        sos = cond_BD.unsqueeze(1).expand(2 * B, self.first_l, -1) + self.pos_start.expand(2 * B, self.first_l, -1)

        lvl_pos = self.lvl_embed(self.lvl_1L) + self.pos_1LC
        next_token_map = sos + lvl_pos[:, :self.first_l]

        cur_L = 0
        f_hat = sos.new_zeros(B, self.Cvae, self.patch_nums[-1], self.patch_nums[-1])

        for b in self.blocks:
            b.attn.kv_caching(True)
        for si, pn in enumerate(self.patch_nums):  # si: i-th segment
            ratio = si / self.num_stages_minus_1
            cur_L += pn * pn
            cond_BD_or_gss = self.shared_ada_lin(cond_BD)
            x = next_token_map

            for b in self.blocks:
                x = b(x=x, cond_BD=cond_BD_or_gss, attn_bias=None)
            logits_BlV = self.get_logits(x, cond_BD)

            t = cfg * ratio
            logits_BlV = (1 + t) * logits_BlV[:B] - t * logits_BlV[B:]

            idx_Bl = sample_with_top_k_top_p_(logits_BlV, rng=rng, top_k=top_k, top_p=top_p, num_samples=1)[:, :, 0]
            if not more_smooth:  # this is the default case
                h_BChw = self.vae_quant_proxy[0].embedding(idx_Bl)  # B, l, Cvae
            else:  # not used when evaluating FID/IS/Precision/Recall
                gum_t = max(0.27 * (1 - ratio * 0.95), 0.005)  # refer to mask-git
                h_BChw = gumbel_softmax_with_rng(logits_BlV.mul(1 + ratio), tau=gum_t, hard=False, dim=-1, rng=rng) @ self.vae_quant_proxy[0].embedding.weight.unsqueeze(0)

            h_BChw = h_BChw.transpose_(1, 2).reshape(B, self.Cvae, pn, pn)
            f_hat, next_token_map = self.vae_quant_proxy[0].get_next_autoregressive_input(si, len(self.patch_nums), f_hat, h_BChw)
            if si != self.num_stages_minus_1:  # prepare for next stage
                next_token_map = next_token_map.view(B, self.Cvae, -1).transpose(1, 2)
                next_token_map = self.word_embed(next_token_map) + lvl_pos[:, cur_L:cur_L + self.patch_nums[si + 1] ** 2]
                next_token_map = next_token_map.repeat(2, 1, 1)  # double the batch sizes due to CFG

        for b in self.blocks:
            b.attn.kv_caching(False)
        return self.vae_proxy[0].fhat_to_img(f_hat).add_(1).mul_(0.5)  # de-normalize, from [-1, 1] to [0, 1]

    def forward(self, label_B: torch.LongTensor, x_BLCv_wo_first_l: torch.Tensor, delta_condition: torch.Tensor, alpha: float, beta: float) -> torch.Tensor:
        """
        :param label_B: class labels, shape (B,)
        :param x_BLCv_wo_first_l: teacher forcing input (B, self.L-self.first_l, self.Cvae)
        :param delta_condition: tensor of shape (B, D)
        :param alpha: scalar weight for class embedding
        :param beta: scalar weight for delta_condition
        :return: logits BLV, V is vocab_size
        """
        bg, ed = self.begin_ends[self.prog_si] if self.prog_si >= 0 else (0, self.L)
        B = x_BLCv_wo_first_l.shape[0]
        with torch.cuda.amp.autocast(enabled=False):
            # Implement conditional dropout
            drop_mask = torch.rand(B, device=label_B.device) < self.cond_drop_rate
            label_B_dropped = torch.where(drop_mask, self.num_classes, label_B)
            delta_condition_dropped = delta_condition.clone()
            delta_condition_dropped[drop_mask] = 0.0  # Drop delta_condition

            class_emb = self.class_emb(label_B_dropped)
            cond_BD = alpha * class_emb + beta * delta_condition_dropped

            sos = cond_BD.unsqueeze(1).expand(B, self.first_l, -1) + self.pos_start.expand(B, self.first_l, -1)

            if self.prog_si == 0:
                x_BLC = sos
            else:
                x_BLC = torch.cat((sos, self.word_embed(x_BLCv_wo_first_l.float())), dim=1)
            x_BLC += self.lvl_embed(self.lvl_1L[:, :ed].expand(B, -1)) + self.pos_1LC[:, :ed]  # lvl: BLC; pos: 1LC

        attn_bias = self.attn_bias_for_masking[:, :, :ed, :ed]
        cond_BD_or_gss = self.shared_ada_lin(cond_BD)

        # hack: get the dtype if mixed precision is used
        temp = x_BLC.new_ones(8, 8)
        main_type = torch.matmul(temp, temp).dtype

        x_BLC = x_BLC.to(dtype=main_type)
        cond_BD_or_gss = cond_BD_or_gss.to(dtype=main_type)
        attn_bias = attn_bias.to(dtype=main_type)

        AdaLNSelfAttn.forward  # (no-op attribute access, kept from the original code)
        for i, b in enumerate(self.blocks):
            x_BLC = b(x=x_BLC, cond_BD=cond_BD_or_gss, attn_bias=attn_bias)
        x_BLC = self.get_logits(x_BLC.float(), cond_BD)

        if self.prog_si == 0:
            # dummy zero-valued terms so word_embed params still receive gradients at progressive stage 0
            if isinstance(self.word_embed, nn.Linear):
                x_BLC[0, 0, 0] += self.word_embed.weight[0, 0] * 0 + self.word_embed.bias[0] * 0
            else:
                s = 0
                for p in self.word_embed.parameters():
                    if p.requires_grad:
                        s += p.view(-1)[0] * 0
                x_BLC[0, 0, 0] += s
        return x_BLC  # logits BLV, V is vocab_size

    def init_weights(self, init_adaln=0.5, init_adaln_gamma=1e-5, init_head=0.02, init_std=0.02, conv_std_or_gain=0.02):
        if init_std < 0: init_std = (1 / self.C / 3) ** 0.5  # init_std < 0: automated

        print(f'[init_weights] {type(self).__name__} with {init_std=:g}')
        for m in self.modules():
            with_weight = hasattr(m, 'weight') and m.weight is not None
            with_bias = hasattr(m, 'bias') and m.bias is not None
            if isinstance(m, nn.Linear):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if with_bias: m.bias.data.zero_()
            elif isinstance(m, nn.Embedding):
                nn.init.trunc_normal_(m.weight.data, std=init_std)
                if m.padding_idx is not None: m.weight.data[m.padding_idx].zero_()
            elif isinstance(m, (nn.LayerNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm, nn.GroupNorm, nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d)):
                if with_weight: m.weight.data.fill_(1.)
                if with_bias: m.bias.data.zero_()
            # conv: VAR has no conv, only VQVAE has conv
            elif isinstance(m, (nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
                if conv_std_or_gain > 0:
                    nn.init.trunc_normal_(m.weight.data, std=conv_std_or_gain)
                else:
                    nn.init.xavier_normal_(m.weight.data, gain=-conv_std_or_gain)
                if with_bias: m.bias.data.zero_()

        if init_head >= 0:
            if isinstance(self.head, nn.Linear):
                self.head.weight.data.mul_(init_head)
                self.head.bias.data.zero_()
            elif isinstance(self.head, nn.Sequential):
                self.head[-1].weight.data.mul_(init_head)
                self.head[-1].bias.data.zero_()

        if isinstance(self.head_nm, AdaLNBeforeHead):
            self.head_nm.ada_lin[-1].weight.data.mul_(init_adaln)
            if hasattr(self.head_nm.ada_lin[-1], 'bias') and self.head_nm.ada_lin[-1].bias is not None:
                self.head_nm.ada_lin[-1].bias.data.zero_()

        depth = len(self.blocks)
        for block_idx, sab in enumerate(self.blocks):
            sab: AdaLNSelfAttn
            sab.attn.proj.weight.data.div_(math.sqrt(2 * depth))
            sab.ffn.fc2.weight.data.div_(math.sqrt(2 * depth))
            if hasattr(sab.ffn, 'fcg') and sab.ffn.fcg is not None:
                nn.init.ones_(sab.ffn.fcg.bias)
                nn.init.trunc_normal_(sab.ffn.fcg.weight, std=1e-5)
            if hasattr(sab, 'ada_lin'):
                sab.ada_lin[-1].weight.data[2 * self.C:].mul_(init_adaln)
                sab.ada_lin[-1].weight.data[:2 * self.C].mul_(init_adaln_gamma)
                if hasattr(sab.ada_lin[-1], 'bias') and sab.ada_lin[-1].bias is not None:
                    sab.ada_lin[-1].bias.data.zero_()
            elif hasattr(sab, 'ada_gss'):
                sab.ada_gss.data[:, :, 2:].mul_(init_adaln)
                sab.ada_gss.data[:, :, :2].mul_(init_adaln_gamma)

    def extra_repr(self):
        return f'drop_path_rate={self.drop_path_rate:g}'


class VARHF(VAR, PyTorchModelHubMixin):
    def __init__(
            self,
            vae_kwargs,
            num_classes=1000, depth=16, embed_dim=1024, num_heads=16, mlp_ratio=4., drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
            norm_eps=1e-6, shared_aln=False, cond_drop_rate=0.1,
            attn_l2_norm=False,
            patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
            flash_if_available=True, fused_if_available=True,
    ):
        vae_local = VQVAE(**vae_kwargs)
        super().__init__(
            vae_local=vae_local,
            num_classes=num_classes, depth=depth, embed_dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio,
            drop_rate=drop_rate, attn_drop_rate=attn_drop_rate, drop_path_rate=drop_path_rate,
            norm_eps=norm_eps, shared_aln=shared_aln, cond_drop_rate=cond_drop_rate,
            attn_l2_norm=attn_l2_norm,
            patch_nums=patch_nums,
            flash_if_available=flash_if_available, fused_if_available=fused_if_available,
        )
models/vqvae.py
ADDED
@@ -0,0 +1,95 @@
"""
References:
- VectorQuantizer2: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L110
- GumbelQuantize: https://github.com/CompVis/taming-transformers/blob/3ba01b241669f5ade541ce990f7650a3b8f65318/taming/modules/vqvae/quantize.py#L213
- VQVAE (VQModel): https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/models/autoencoder.py#L14
"""
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn

from .basic_vae import Decoder, Encoder
from .quant import VectorQuantizer2


class VQVAE(nn.Module):
    def __init__(
        self, vocab_size=4096, z_channels=32, ch=128, dropout=0.0,
        beta=0.25,               # commitment loss weight
        using_znorm=False,       # whether to normalize when computing the nearest neighbors
        quant_conv_ks=3,         # quant conv kernel size
        quant_resi=0.5,          # 0.5 means \phi(x) = 0.5conv(x) + (1-0.5)x
        share_quant_resi=4,      # use 4 \phi layers for K scales: partially-shared \phi
        default_qresi_counts=0,  # if is 0: automatically set to len(v_patch_nums)
        v_patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # number of patches for each scale, h_{1 to K} = w_{1 to K} = v_patch_nums[k]
        test_mode=True,
    ):
        super().__init__()
        self.test_mode = test_mode
        self.V, self.Cvae = vocab_size, z_channels
        # ddconfig is copied from https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/vq-f16/config.yaml
        ddconfig = dict(
            dropout=dropout, ch=ch, z_channels=z_channels,
            in_channels=3, ch_mult=(1, 1, 2, 2, 4), num_res_blocks=2,  # from vq-f16/config.yaml above
            using_sa=True, using_mid_sa=True,                          # from vq-f16/config.yaml above
            # resamp_with_conv=True,  # always True, removed.
        )
        ddconfig.pop('double_z', None)  # only KL-VAE should use double_z=True
        self.encoder = Encoder(double_z=False, **ddconfig)
        self.decoder = Decoder(**ddconfig)

        self.vocab_size = vocab_size
        self.downsample = 2 ** (len(ddconfig['ch_mult']) - 1)
        self.quantize: VectorQuantizer2 = VectorQuantizer2(
            vocab_size=vocab_size, Cvae=self.Cvae, using_znorm=using_znorm, beta=beta,
            default_qresi_counts=default_qresi_counts, v_patch_nums=v_patch_nums, quant_resi=quant_resi, share_quant_resi=share_quant_resi,
        )
        self.quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2)
        self.post_quant_conv = torch.nn.Conv2d(self.Cvae, self.Cvae, quant_conv_ks, stride=1, padding=quant_conv_ks // 2)

        if self.test_mode:
            self.eval()
            [p.requires_grad_(False) for p in self.parameters()]

    # ===================== `forward` is only used in VAE training =====================
    def forward(self, inp, ret_usages=False):  # -> rec_B3HW, idx_N, loss
        VectorQuantizer2.forward  # (no-op attribute access, kept from the original code)
        f_hat, usages, vq_loss = self.quantize(self.quant_conv(self.encoder(inp)), ret_usages=ret_usages)
        return self.decoder(self.post_quant_conv(f_hat)), usages, vq_loss
    # ===================== `forward` is only used in VAE training =====================

    def fhat_to_img(self, f_hat: torch.Tensor):
        return self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1)

    def img_to_idxBl(self, inp_img_no_grad: torch.Tensor, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None) -> List[torch.LongTensor]:  # return List[Bl]
        f = self.quant_conv(self.encoder(inp_img_no_grad))
        return self.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=v_patch_nums)

    def idxBl_to_img(self, ms_idx_Bl: List[torch.Tensor], same_shape: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        B = ms_idx_Bl[0].shape[0]
        ms_h_BChw = []
        for idx_Bl in ms_idx_Bl:
            l = idx_Bl.shape[1]
            pn = round(l ** 0.5)
            ms_h_BChw.append(self.quantize.embedding(idx_Bl).transpose(1, 2).view(B, self.Cvae, pn, pn))
        return self.embed_to_img(ms_h_BChw=ms_h_BChw, all_to_max_scale=same_shape, last_one=last_one)

    def embed_to_img(self, ms_h_BChw: List[torch.Tensor], all_to_max_scale: bool, last_one=False) -> Union[List[torch.Tensor], torch.Tensor]:
        if last_one:
            return self.decoder(self.post_quant_conv(self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=True))).clamp_(-1, 1)
        else:
            return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in self.quantize.embed_to_fhat(ms_h_BChw, all_to_max_scale=all_to_max_scale, last_one=False)]

    def img_to_reconstructed_img(self, x, v_patch_nums: Optional[Sequence[Union[int, Tuple[int, int]]]] = None, last_one=False) -> List[torch.Tensor]:
        f = self.quant_conv(self.encoder(x))
        ls_f_hat_BChw = self.quantize.f_to_idxBl_or_fhat(f, to_fhat=True, v_patch_nums=v_patch_nums)
        if last_one:
            return self.decoder(self.post_quant_conv(ls_f_hat_BChw[-1])).clamp_(-1, 1)
        else:
            return [self.decoder(self.post_quant_conv(f_hat)).clamp_(-1, 1) for f_hat in ls_f_hat_BChw]

    def load_state_dict(self, state_dict: Dict[str, Any], strict=True, assign=False):
        if 'quantize.ema_vocab_hit_SV' in state_dict and state_dict['quantize.ema_vocab_hit_SV'].shape[0] != self.quantize.ema_vocab_hit_SV.shape[0]:
            state_dict['quantize.ema_vocab_hit_SV'] = self.quantize.ema_vocab_hit_SV
        return super().load_state_dict(state_dict=state_dict, strict=strict, assign=assign)
utils/amp_sc.py
ADDED
@@ -0,0 +1,89 @@
import math
from typing import List, Optional, Tuple, Union

import torch


class NullCtx:
    def __enter__(self):
        pass

    def __exit__(self, exc_type, exc_val, exc_tb):
        pass


class AmpOptimizer:
    def __init__(
        self,
        mixed_precision: int,
        optimizer: torch.optim.Optimizer, names: List[str], paras: List[torch.nn.Parameter],
        grad_clip: float, n_gradient_accumulation: int = 1,
    ):
        self.enable_amp = mixed_precision > 0
        self.using_fp16_rather_bf16 = mixed_precision == 1

        if self.enable_amp:
            self.amp_ctx = torch.autocast('cuda', enabled=True, dtype=torch.float16 if self.using_fp16_rather_bf16 else torch.bfloat16, cache_enabled=True)
            self.scaler = torch.cuda.amp.GradScaler(init_scale=2. ** 11, growth_interval=1000) if self.using_fp16_rather_bf16 else None  # only fp16 needs a scaler
        else:
            self.amp_ctx = NullCtx()
            self.scaler = None

        self.optimizer, self.names, self.paras = optimizer, names, paras  # paras have been filtered so everyone requires grad
        self.grad_clip = grad_clip
        self.early_clipping = self.grad_clip > 0 and not hasattr(optimizer, 'global_grad_norm')
        self.late_clipping = self.grad_clip > 0 and hasattr(optimizer, 'global_grad_norm')

        self.r_accu = 1 / n_gradient_accumulation  # r_accu == 1.0 / n_gradient_accumulation

    def backward_clip_step(
        self, stepping: bool, loss: torch.Tensor,
    ) -> Tuple[Optional[Union[torch.Tensor, float]], Optional[float]]:
        # backward
        loss = loss.mul(self.r_accu)  # r_accu == 1.0 / n_gradient_accumulation
        orig_norm = scaler_sc = None
        if self.scaler is not None:
            self.scaler.scale(loss).backward(retain_graph=False, create_graph=False)
        else:
            loss.backward(retain_graph=False, create_graph=False)

        if stepping:
            if self.scaler is not None: self.scaler.unscale_(self.optimizer)
            if self.early_clipping:
                orig_norm = torch.nn.utils.clip_grad_norm_(self.paras, self.grad_clip)

            if self.scaler is not None:
                self.scaler.step(self.optimizer)
                scaler_sc: float = self.scaler.get_scale()
                if scaler_sc > 32768.:  # fp16 will overflow when > 65536, so growing past 32768 could be dangerous
                    self.scaler.update(new_scale=32768.)
                else:
                    self.scaler.update()
                try:
                    scaler_sc = float(math.log2(scaler_sc))
                except Exception as e:
                    print(f'[scaler_sc = {scaler_sc}]\n' * 15, flush=True)
                    raise e
            else:
                self.optimizer.step()

            if self.late_clipping:
                orig_norm = self.optimizer.global_grad_norm

            self.optimizer.zero_grad(set_to_none=True)

        return orig_norm, scaler_sc

    def state_dict(self):
        return {
            'optimizer': self.optimizer.state_dict()
        } if self.scaler is None else {
            'scaler': self.scaler.state_dict(),
            'optimizer': self.optimizer.state_dict()
        }

    def load_state_dict(self, state, strict=True):
        if self.scaler is not None:
            try: self.scaler.load_state_dict(state['scaler'])
            except Exception as e: print(f'[fp16 load_state_dict err] {e}')
        self.optimizer.load_state_dict(state['optimizer'])
utils/arg_util.py
ADDED
@@ -0,0 +1,284 @@
import json
import os
import random
import re
import subprocess
import sys
import time
from collections import OrderedDict
from typing import Optional, Union

import numpy as np
import torch

try:
    from tap import Tap
except ImportError as e:
    print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
    print(f'`>>>>>>>> from tap import Tap` failed, please run: pip3 install typed-argument-parser <<<<<<<<', file=sys.stderr, flush=True)
    time.sleep(5)
    raise e

import dist


class Args(Tap):
    data_path: str = '/path/to/imagenet'
    exp_name: str = 'text'

    # VAE
    vfast: int = 0      # torch.compile VAE; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
    # VAR
    tfast: int = 0      # torch.compile VAR; =0: not compile; 1: compile with 'reduce-overhead'; 2: compile with 'max-autotune'
    depth: int = 16     # VAR depth
    # VAR initialization
    ini: float = -1     # -1: automated model parameter initialization
    hd: float = 0.02    # head.w *= hd
    aln: float = 0.5    # the multiplier of ada_lin.w's initialization
    alng: float = 1e-5  # the multiplier of ada_lin.w[gamma channels]'s initialization
    # VAR optimization
    fp16: int = 0       # 1: using fp16, 2: bf16
    tblr: float = 1e-4  # base lr
    tlr: float = None   # lr = base lr * (bs / 256)
    twd: float = 0.05   # initial wd
    twde: float = 0     # final wd, =twde or twd
    tclip: float = 2.   # <=0 for not using grad clip
    ls: float = 0.0     # label smooth

    bs: int = 768           # global batch size
    batch_size: int = 0     # [automatically set; don't specify this] batch size per GPU = round(args.bs / args.ac / dist.get_world_size())
    glb_batch_size: int = 0 # [automatically set; don't specify this] global batch size = args.batch_size * dist.get_world_size()
    ac: int = 1             # gradient accumulation

    ep: int = 250
    wp: float = 0
    wp0: float = 0.005  # initial lr ratio at the beginning of lr warm up
    wpe: float = 0.01   # final lr ratio at the end of training
    sche: str = 'lin0'  # lr schedule

    opt: str = 'adamw'  # lion: https://cloud.tencent.com/developer/article/2336657?areaId=106001 lr=5e-5 (0.25x) wd=0.8 (8x); Lion needs a large bs to work
    afuse: bool = True  # fused adamw

    # other hps
    saln: bool = False  # whether to use shared adaln
    anorm: bool = True  # whether to use L2 normalized attention
    fuse: bool = True   # whether to use fused op like flash attn, xformers, fused MLP, fused LayerNorm, etc.

    # data
    pn: str = '1_2_3_4_5_6_8_10_13_16'
    patch_size: int = 16
    patch_nums: tuple = None    # [automatically set; don't specify this] = tuple(map(int, args.pn.replace('-', '_').split('_')))
    resos: tuple = None         # [automatically set; don't specify this] = tuple(pn * args.patch_size for pn in args.patch_nums)

    data_load_reso: int = None  # [automatically set; don't specify this] would be max(patch_nums) * patch_size
    mid_reso: float = 1.125     # aug: first resize to mid_reso = 1.125 * data_load_reso, then crop to data_load_reso
    hflip: bool = False         # augmentation: horizontal flip
    workers: int = 0            # num workers; 0: auto, -1: don't use multiprocessing in DataLoader

    # progressive training
    pg: float = 0.0     # >0 for use progressive training during [0%, this] of training
    pg0: int = 4        # progressive initial stage, 0: from the 1st token map, 1: from the 2nd token map, etc
    pgwp: float = 0     # num of warmup epochs at each progressive stage

    # would be automatically set in runtime
    cmd: str = ' '.join(sys.argv[1:])  # [automatically set; don't specify this]
    branch: str = subprocess.check_output(f'git symbolic-ref --short HEAD 2>/dev/null || git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_id: str = subprocess.check_output(f'git rev-parse HEAD', shell=True).decode('utf-8').strip() or '[unknown]'  # [automatically set; don't specify this]
    commit_msg: str = (subprocess.check_output(f'git log -1', shell=True).decode('utf-8').strip().splitlines() or ['[unknown]'])[-1].strip()  # [automatically set; don't specify this]
    acc_mean: float = None     # [automatically set; don't specify this]
    acc_tail: float = None     # [automatically set; don't specify this]
    L_mean: float = None       # [automatically set; don't specify this]
    L_tail: float = None       # [automatically set; don't specify this]
    vacc_mean: float = None    # [automatically set; don't specify this]
    vacc_tail: float = None    # [automatically set; don't specify this]
    vL_mean: float = None      # [automatically set; don't specify this]
    vL_tail: float = None      # [automatically set; don't specify this]
    grad_norm: float = None    # [automatically set; don't specify this]
    cur_lr: float = None       # [automatically set; don't specify this]
    cur_wd: float = None       # [automatically set; don't specify this]
    cur_it: str = ''           # [automatically set; don't specify this]
    cur_ep: str = ''           # [automatically set; don't specify this]
    remain_time: str = ''      # [automatically set; don't specify this]
    finish_time: str = ''      # [automatically set; don't specify this]

    # environment
    local_out_dir_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'local_output')  # [automatically set; don't specify this]
    tb_log_dir_path: str = '...tb-...'  # [automatically set; don't specify this]
    log_txt_path: str = '...'           # [automatically set; don't specify this]
    last_ckpt_path: str = '...'         # [automatically set; don't specify this]

    tf32: bool = True   # whether to use TensorFloat32
    device: str = 'cpu' # [automatically set; don't specify this]
    seed: int = None    # random seed; None for nondeterministic
    def seed_everything(self, benchmark: bool):
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = benchmark
        if self.seed is None:
            torch.backends.cudnn.deterministic = False
        else:
            torch.backends.cudnn.deterministic = True
            seed = self.seed * dist.get_world_size() + dist.get_rank()
            os.environ['PYTHONHASHSEED'] = str(seed)
            random.seed(seed)
            np.random.seed(seed)
            torch.manual_seed(seed)
            if torch.cuda.is_available():
                torch.cuda.manual_seed(seed)
                torch.cuda.manual_seed_all(seed)
    same_seed_for_all_ranks: int = 0  # this is only for distributed sampler
    def get_different_generator_for_each_rank(self) -> Optional[torch.Generator]:  # for random augmentation
        if self.seed is None:
            return None
        g = torch.Generator()
        g.manual_seed(self.seed * dist.get_world_size() + dist.get_rank())
        return g

    local_debug: bool = 'KEVIN_LOCAL' in os.environ
    dbg_nan: bool = False  # 'KEVIN_LOCAL' in os.environ

    def compile_model(self, m, fast):
        if fast == 0 or self.local_debug:
            return m
        return torch.compile(m, mode={
            1: 'reduce-overhead',
            2: 'max-autotune',
            3: 'default',
        }[fast]) if hasattr(torch, 'compile') else m

    def state_dict(self, key_ordered=True) -> Union[OrderedDict, dict]:
        d = (OrderedDict if key_ordered else dict)()
        # self.as_dict() would contain methods, but we only need variables
        for k in self.class_variables.keys():
            if k not in {'device'}:  # these are not serializable
                d[k] = getattr(self, k)
        return d

    def load_state_dict(self, d: Union[OrderedDict, dict, str]):
        if isinstance(d, str):  # for compatibility with old version
            d: dict = eval('\n'.join([l for l in d.splitlines() if '<bound' not in l and 'device(' not in l]))
        for k in d.keys():
            try:
                setattr(self, k, d[k])
            except Exception as e:
                print(f'k={k}, v={d[k]}')
                raise e

    @staticmethod
    def set_tf32(tf32: bool):
        if torch.cuda.is_available():
            torch.backends.cudnn.allow_tf32 = bool(tf32)
            torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
            if hasattr(torch, 'set_float32_matmul_precision'):
                torch.set_float32_matmul_precision('high' if tf32 else 'highest')
                print(f'[tf32] [precis] torch.get_float32_matmul_precision(): {torch.get_float32_matmul_precision()}')
            print(f'[tf32] [ conv ] torch.backends.cudnn.allow_tf32: {torch.backends.cudnn.allow_tf32}')
            print(f'[tf32] [matmul] torch.backends.cuda.matmul.allow_tf32: {torch.backends.cuda.matmul.allow_tf32}')

    def dump_log(self):
        if not dist.is_local_master():
            return
        if '1/' in self.cur_ep:  # first time to dump log
            with open(self.log_txt_path, 'w') as fp:
                json.dump({'is_master': dist.is_master(), 'name': self.exp_name, 'cmd': self.cmd, 'commit': self.commit_id, 'branch': self.branch, 'tb_log_dir_path': self.tb_log_dir_path}, fp, indent=0)
                fp.write('\n')

        log_dict = {}
        for k, v in {
            'it': self.cur_it, 'ep': self.cur_ep,
            'lr': self.cur_lr, 'wd': self.cur_wd, 'grad_norm': self.grad_norm,
            'L_mean': self.L_mean, 'L_tail': self.L_tail, 'acc_mean': self.acc_mean, 'acc_tail': self.acc_tail,
            'vL_mean': self.vL_mean, 'vL_tail': self.vL_tail, 'vacc_mean': self.vacc_mean, 'vacc_tail': self.vacc_tail,
            'remain_time': self.remain_time, 'finish_time': self.finish_time,
        }.items():
            if hasattr(v, 'item'): v = v.item()
            log_dict[k] = v
        with open(self.log_txt_path, 'a') as fp:
            fp.write(f'{log_dict}\n')

    def __str__(self):
        s = []
        for k in self.class_variables.keys():
            if k not in {'device', 'dbg_ks_fp'}:  # these are not serializable
                s.append(f'  {k:20s}: {getattr(self, k)}')
        s = '\n'.join(s)
        return f'{{\n{s}\n}}\n'


def init_dist_and_get_args():
    for i in range(len(sys.argv)):
        if sys.argv[i].startswith('--local-rank=') or sys.argv[i].startswith('--local_rank='):
            del sys.argv[i]
            break
    args = Args(explicit_bool=True).parse_args(known_only=True)
    if args.local_debug:
        args.pn = '1_2_3'
        args.seed = 1
        args.aln = 1e-2
        args.alng = 1e-5
        args.saln = False
        args.afuse = False
        args.pg = 0.8
        args.pg0 = 1
    else:
        if args.data_path == '/path/to/imagenet':
            raise ValueError(f'{"*"*40}  please specify --data_path=/path/to/imagenet  {"*"*40}')

    # warn args.extra_args
    if len(args.extra_args) > 0:
        print(f'======================================================================================')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================\n{args.extra_args}')
        print(f'=========================== WARNING: UNEXPECTED EXTRA ARGS ===========================')
        print(f'======================================================================================\n\n')

    # init torch distributed
    from utils import misc
    os.makedirs(args.local_out_dir_path, exist_ok=True)
    misc.init_distributed_mode(local_out_path=args.local_out_dir_path, timeout=30)

    # set env
    args.set_tf32(args.tf32)
    args.seed_everything(benchmark=args.pg == 0)

    # update args: data loading
    args.device = dist.get_device()
    if args.pn == '256':
        args.pn = '1_2_3_4_5_6_8_10_13_16'
    elif args.pn == '512':
        args.pn = '1_2_3_4_6_9_13_18_24_32'
    elif args.pn == '1024':
        args.pn = '1_2_3_4_5_7_9_12_16_21_27_36_48_64'
    args.patch_nums = tuple(map(int, args.pn.replace('-', '_').split('_')))
    args.resos = tuple(pn * args.patch_size for pn in args.patch_nums)
    args.data_load_reso = max(args.resos)

    # update args: bs and lr
    bs_per_gpu = round(args.bs / args.ac / dist.get_world_size())
    args.batch_size = bs_per_gpu
    args.bs = args.glb_batch_size = args.batch_size * dist.get_world_size()
    args.workers = min(max(0, args.workers), args.batch_size)

    args.tlr = args.ac * args.tblr * args.glb_batch_size / 256
    args.twde = args.twde or args.twd

    if args.wp == 0:
        args.wp = args.ep * 1/50

    # update args: progressive training
    if args.pgwp == 0:
        args.pgwp = args.ep * 1/300
    if args.pg > 0:
        args.sche = f'lin{args.pg:g}'

    # update args: paths
    args.log_txt_path = os.path.join(args.local_out_dir_path, 'log.txt')
    args.last_ckpt_path = os.path.join(args.local_out_dir_path, f'ar-ckpt-last.pth')
    _reg_valid_name = re.compile(r'[^\w\-+,.]')
    tb_name = _reg_valid_name.sub(
        '_',
        f'tb-VARd{args.depth}'
        f'__pn{args.pn}'
        f'__b{args.bs}ep{args.ep}{args.opt[:4]}lr{args.tblr:g}wd{args.twd:g}'
    )
    args.tb_log_dir_path = os.path.join(args.local_out_dir_path, tb_name)

    return args
utils/data.py
ADDED
@@ -0,0 +1,54 @@
import os.path as osp

import PIL.Image as PImage
from torchvision.datasets.folder import DatasetFolder, IMG_EXTENSIONS
from torchvision.transforms import InterpolationMode, transforms


def normalize_01_into_pm1(x):  # normalize x from [0, 1] to [-1, 1] by (x*2) - 1
    return x.add(x).add_(-1)


def build_dataset(
    data_path: str, final_reso: int,
    hflip=False, mid_reso=1.125,
):
    # build augmentations
    mid_reso = round(mid_reso * final_reso)  # first resize to mid_reso, then crop to final_reso
    train_aug, val_aug = [
        transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS),  # transforms.Resize: resize the shorter edge to mid_reso
        transforms.RandomCrop((final_reso, final_reso)),
        transforms.ToTensor(), normalize_01_into_pm1,
    ], [
        transforms.Resize(mid_reso, interpolation=InterpolationMode.LANCZOS),  # transforms.Resize: resize the shorter edge to mid_reso
        transforms.CenterCrop((final_reso, final_reso)),
        transforms.ToTensor(), normalize_01_into_pm1,
    ]
    if hflip: train_aug.insert(0, transforms.RandomHorizontalFlip())
    train_aug, val_aug = transforms.Compose(train_aug), transforms.Compose(val_aug)

    # build dataset
    train_set = DatasetFolder(root=osp.join(data_path, 'train'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=train_aug)
    val_set = DatasetFolder(root=osp.join(data_path, 'val'), loader=pil_loader, extensions=IMG_EXTENSIONS, transform=val_aug)
    num_classes = 1000
    print(f'[Dataset] {len(train_set)=}, {len(val_set)=}, {num_classes=}')
    print_aug(train_aug, '[train]')
    print_aug(val_aug, '[val]')

    return num_classes, train_set, val_set


def pil_loader(path):
    with open(path, 'rb') as f:
        img: PImage.Image = PImage.open(f).convert('RGB')
    return img


def print_aug(transform, label):
    print(f'Transform {label} = ')
    if hasattr(transform, 'transforms'):
        for t in transform.transforms:
            print(t)
    else:
        print(transform)
    print('---------------------------\n')
utils/data_sampler.py
ADDED
@@ -0,0 +1,103 @@
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch.utils.data.sampler import Sampler
|
4 |
+
|
5 |
+
|
6 |
+
class EvalDistributedSampler(Sampler):
|
7 |
+
def __init__(self, dataset, num_replicas, rank):
|
8 |
+
seps = np.linspace(0, len(dataset), num_replicas+1, dtype=int)
|
9 |
+
beg, end = seps[:-1], seps[1:]
|
10 |
+
beg, end = beg[rank], end[rank]
|
11 |
+
self.indices = tuple(range(beg, end))
|
12 |
+
|
13 |
+
def __iter__(self):
|
14 |
+
return iter(self.indices)
|
15 |
+
|
16 |
+
def __len__(self) -> int:
|
17 |
+
return len(self.indices)
|
18 |
+
|
19 |
+
|
20 |
+
class InfiniteBatchSampler(Sampler):
|
21 |
+
def __init__(self, dataset_len, batch_size, seed_for_all_rank=0, fill_last=False, shuffle=True, drop_last=False, start_ep=0, start_it=0):
|
22 |
+
self.dataset_len = dataset_len
|
23 |
+
self.batch_size = batch_size
|
24 |
+
self.iters_per_ep = dataset_len // batch_size if drop_last else (dataset_len + batch_size - 1) // batch_size
|
25 |
+
self.max_p = self.iters_per_ep * batch_size
|
26 |
+
self.fill_last = fill_last
|
27 |
+
self.shuffle = shuffle
|
28 |
+
self.epoch = start_ep
|
29 |
+
self.same_seed_for_all_ranks = seed_for_all_rank
|
30 |
+
self.indices = self.gener_indices()
|
31 |
+
self.start_ep, self.start_it = start_ep, start_it
|
32 |
+
|
33 |
+
def gener_indices(self):
|
34 |
+
if self.shuffle:
|
35 |
+
g = torch.Generator()
|
36 |
+
g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
|
37 |
+
indices = torch.randperm(self.dataset_len, generator=g).numpy()
|
38 |
+
else:
|
39 |
+
indices = torch.arange(self.dataset_len).numpy()
|
40 |
+
|
41 |
+
tails = self.batch_size - (self.dataset_len % self.batch_size)
|
42 |
+
if tails != self.batch_size and self.fill_last:
|
43 |
+
tails = indices[:tails]
|
44 |
+
np.random.shuffle(indices)
|
45 |
+
indices = np.concatenate((indices, tails))
|
46 |
+
|
47 |
+
# built-in list/tuple is faster than np.ndarray (when collating the data via a for-loop)
|
48 |
+
# noinspection PyTypeChecker
|
49 |
+
return tuple(indices.tolist())
|
50 |
+
|
51 |
+
def __iter__(self):
|
52 |
+
self.epoch = self.start_ep
|
53 |
+
while True:
|
54 |
+
self.epoch += 1
|
55 |
+
p = (self.start_it * self.batch_size) if self.epoch == self.start_ep else 0
|
56 |
+
while p < self.max_p:
|
57 |
+
q = p + self.batch_size
|
58 |
+
yield self.indices[p:q]
|
59 |
+
p = q
|
60 |
+
if self.shuffle:
|
61 |
+
self.indices = self.gener_indices()
|
62 |
+
    def __len__(self):
        return self.iters_per_ep


class DistInfiniteBatchSampler(InfiniteBatchSampler):
    def __init__(self, world_size, rank, dataset_len, glb_batch_size, same_seed_for_all_ranks=0, repeated_aug=0, fill_last=False, shuffle=True, start_ep=0, start_it=0):
        assert glb_batch_size % world_size == 0
        self.world_size, self.rank = world_size, rank
        self.dataset_len = dataset_len
        self.glb_batch_size = glb_batch_size
        self.batch_size = glb_batch_size // world_size
        
        self.iters_per_ep = (dataset_len + glb_batch_size - 1) // glb_batch_size
        self.fill_last = fill_last
        self.shuffle = shuffle
        self.repeated_aug = repeated_aug
        self.epoch = start_ep
        self.same_seed_for_all_ranks = same_seed_for_all_ranks
        self.indices = self.gener_indices()
        self.start_ep, self.start_it = start_ep, start_it
    
    def gener_indices(self):
        global_max_p = self.iters_per_ep * self.glb_batch_size  # global_max_p % world_size == 0, since glb_batch_size % world_size == 0
        if self.shuffle:
            g = torch.Generator()
            g.manual_seed(self.epoch + self.same_seed_for_all_ranks)
            global_indices = torch.randperm(self.dataset_len, generator=g)
            if self.repeated_aug > 1:
                global_indices = global_indices[:(self.dataset_len + self.repeated_aug - 1) // self.repeated_aug].repeat_interleave(self.repeated_aug, dim=0)[:global_max_p]
        else:
            global_indices = torch.arange(self.dataset_len)
        filling = global_max_p - global_indices.shape[0]
        if filling > 0 and self.fill_last:
            global_indices = torch.cat((global_indices, global_indices[:filling]))
        
        seps = torch.linspace(0, global_indices.shape[0], self.world_size + 1, dtype=torch.int)
        local_indices = global_indices[seps[self.rank].item():seps[self.rank + 1].item()].tolist()
        self.max_p = len(local_indices)
        return local_indices
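The sampler above never raises StopIteration, so a single DataLoader can be iterated across epoch boundaries. As a minimal usage sketch (not part of this commit; the toy TensorDataset and the batch/worker sizes are assumptions), it is wired in via batch_sampler:

import torch
from torch.utils.data import DataLoader, TensorDataset
import dist
from utils.data_sampler import DistInfiniteBatchSampler

dataset = TensorDataset(torch.randn(10_000, 8))  # toy stand-in for a real dataset
sampler = DistInfiniteBatchSampler(
    world_size=dist.get_world_size(), rank=dist.get_rank(),
    dataset_len=len(dataset), glb_batch_size=256, shuffle=True, fill_last=True,
)
loader = DataLoader(dataset, batch_sampler=sampler, num_workers=4, pin_memory=True)
data_iter = iter(loader)
batch = next(data_iter)  # call once per training step; the iterator never ends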
utils/lr_control.py
ADDED
@@ -0,0 +1,108 @@
import math
from pprint import pformat
from typing import Tuple, List, Dict, Union

import torch.nn

import dist


def lr_wd_annealing(sche_type: str, optimizer, peak_lr, wd, wd_end, cur_it, wp_it, max_it, wp0=0.005, wpe=0.001):
    """Anneal the learning rate and weight decay: linear warmup, then the chosen decay schedule."""
    wp_it = round(wp_it)
    
    if cur_it < wp_it:
        cur_lr = wp0 + (1-wp0) * cur_it / wp_it
    else:
        pasd = (cur_it - wp_it) / (max_it-1 - wp_it)    # passed: 0 -> 1
        rest = 1 - pasd                                 # remaining: 1 -> 0
        if sche_type == 'cos':
            cur_lr = wpe + (1-wpe) * (0.5 + 0.5 * math.cos(math.pi * pasd))
        elif sche_type == 'lin':
            T = 0.15; max_rest = 1-T
            if pasd < T: cur_lr = 1
            else: cur_lr = wpe + (1-wpe) * rest / max_rest  # 1 -> wpe
        elif sche_type == 'lin0':
            T = 0.05; max_rest = 1-T
            if pasd < T: cur_lr = 1
            else: cur_lr = wpe + (1-wpe) * rest / max_rest
        elif sche_type == 'lin00':
            cur_lr = wpe + (1-wpe) * rest
        elif sche_type.startswith('lin'):
            T = float(sche_type[3:]); max_rest = 1-T
            wpe_mid = wpe + (1-wpe) * max_rest
            wpe_mid = (1 + wpe_mid) / 2
            if pasd < T: cur_lr = 1 + (wpe_mid-1) * pasd / T
            else: cur_lr = wpe + (wpe_mid-wpe) * rest / max_rest
        elif sche_type == 'exp':
            T = 0.15; max_rest = 1-T
            if pasd < T: cur_lr = 1
            else:
                expo = (pasd-T) / max_rest * math.log(wpe)
                cur_lr = math.exp(expo)
        else:
            raise NotImplementedError(f'unknown sche_type {sche_type}')
    
    cur_lr *= peak_lr   # scale the unit-less multiplier to the actual learning rate
    pasd = cur_it / (max_it-1)
    cur_wd = wd_end + (wd - wd_end) * (0.5 + 0.5 * math.cos(math.pi * pasd))
    
    inf = 1e6
    min_lr, max_lr = inf, -1
    min_wd, max_wd = inf, -1
    for param_group in optimizer.param_groups:
        param_group['lr'] = cur_lr * param_group.get('lr_sc', 1)  # 'lr_sc' could be assigned per group
        max_lr = max(max_lr, param_group['lr'])
        min_lr = min(min_lr, param_group['lr'])
        
        param_group['weight_decay'] = cur_wd * param_group.get('wd_sc', 1)
        max_wd = max(max_wd, param_group['weight_decay'])
        if param_group['weight_decay'] > 0:
            min_wd = min(min_wd, param_group['weight_decay'])
    
    if min_lr == inf: min_lr = -1
    if min_wd == inf: min_wd = -1
    return min_lr, max_lr, min_wd, max_wd


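For reference, a minimal sketch of how lr_wd_annealing is meant to be called, once per optimizer step (not from this commit; the model, peak LR, and step counts are made-up example values):

import torch
from utils.lr_control import lr_wd_annealing

model = torch.nn.Linear(16, 16)
opt = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
max_it, wp_it = 10_000, 500  # total steps and warmup steps
for cur_it in range(max_it):
    min_lr, max_lr, min_wd, max_wd = lr_wd_annealing(
        'lin0', opt, peak_lr=1e-4, wd=0.05, wd_end=0.0,
        cur_it=cur_it, wp_it=wp_it, max_it=max_it,
    )
    # ... forward / backward ...
    opt.step(); opt.zero_grad()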
def filter_params(model, nowd_keys=()) -> Tuple[
    List[str], List[torch.nn.Parameter], List[Dict[str, Union[torch.nn.Parameter, float]]]
]:
    para_groups, para_groups_dbg = {}, {}
    names, paras = [], []
    names_no_grad = []
    count, numel = 0, 0
    for name, para in model.named_parameters():
        name = name.replace('_fsdp_wrapped_module.', '')
        if not para.requires_grad:
            names_no_grad.append(name)
            continue    # frozen weights
        count += 1
        numel += para.numel()
        names.append(name)
        paras.append(para)
        
        if para.ndim == 1 or name.endswith('bias') or any(k in name for k in nowd_keys):
            cur_wd_sc, group_name = 0., 'ND'    # no decay: 1-D params, biases, and listed keys
        else:
            cur_wd_sc, group_name = 1., 'D'     # weight decay applied
        cur_lr_sc = 1.
        if group_name not in para_groups:
            para_groups[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
            para_groups_dbg[group_name] = {'params': [], 'wd_sc': cur_wd_sc, 'lr_sc': cur_lr_sc}
        para_groups[group_name]['params'].append(para)
        para_groups_dbg[group_name]['params'].append(name)
    
    for g in para_groups_dbg.values():
        g['params'] = pformat(', '.join(g['params']), width=200)
    
    print(f'[get_param_groups] param_groups = \n{pformat(para_groups_dbg, indent=2, width=240)}\n')
    
    for rk in range(dist.get_world_size()):
        dist.barrier()
        if dist.get_rank() == rk:
            print(f'[get_param_groups][rank{dist.get_rank()}] {type(model).__name__=} {count=}, {numel=}', flush=True, force=True)
    print('')
    
    assert len(names_no_grad) == 0, f'[get_param_groups] names_no_grad = \n{pformat(names_no_grad, indent=2, width=240)}\n'
    return names, paras, list(para_groups.values())
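A minimal sketch of consuming filter_params (not from this commit): the returned group dicts can be passed straight to a torch optimizer, which keeps the extra 'lr_sc'/'wd_sc' keys so lr_wd_annealing can rescale each group later. It assumes init_distributed_mode has already initialized dist and patched print (filter_params calls dist.barrier() and print(..., force=True)); the model and nowd_keys below are illustrative.

import torch
from utils.lr_control import filter_params

# model: any torch.nn.Module; nowd_keys: illustrative no-weight-decay name fragments
names, paras, para_groups = filter_params(model, nowd_keys={'class_emb', 'pos_1LC'})
opt = torch.optim.AdamW(para_groups, lr=1e-4, weight_decay=0.05)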
utils/misc.py
ADDED
@@ -0,0 +1,381 @@
import datetime
import functools
import glob
import os
import subprocess
import sys
import time
from collections import defaultdict, deque
from typing import Iterator, List, Tuple

import numpy as np
import pytz
import torch
import torch.distributed as tdist

import dist
from utils import arg_util

os_system = functools.partial(subprocess.call, shell=True)
def echo(info):
    os_system(f'echo "[$(date "+%m-%d-%H:%M:%S")] ({os.path.basename(sys._getframe().f_back.f_code.co_filename)}, line{sys._getframe().f_back.f_lineno})=> {info}"')
def os_system_get_stdout(cmd):
    return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.decode('utf-8')
def os_system_get_stdout_stderr(cmd):
    cnt = 0
    while True:
        try:
            sp = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=30)
        except subprocess.TimeoutExpired:
            cnt += 1
            print(f'[fetch free_port file] timeout cnt={cnt}')
        else:
            return sp.stdout.decode('utf-8'), sp.stderr.decode('utf-8')


def time_str(fmt='[%m-%d %H:%M:%S]'):
    return datetime.datetime.now(tz=pytz.timezone('Asia/Shanghai')).strftime(fmt)


def init_distributed_mode(local_out_path, only_sync_master=False, timeout=30):
    try:
        dist.initialize(fork=False, timeout=timeout)
        dist.barrier()
    except RuntimeError:
        print(f'{">"*75} NCCL Error {"<"*75}', flush=True)
        time.sleep(10)
    
    if local_out_path is not None: os.makedirs(local_out_path, exist_ok=True)
    _change_builtin_print(dist.is_local_master())
    if (dist.is_master() if only_sync_master else dist.is_local_master()) and local_out_path is not None and len(local_out_path):
        sys.stdout, sys.stderr = SyncPrint(local_out_path, sync_stdout=True), SyncPrint(local_out_path, sync_stdout=False)


def _change_builtin_print(is_master):
    import builtins as __builtin__
    
    builtin_print = __builtin__.print
    if type(builtin_print) != type(open):
        return  # print has already been replaced (prt is a plain function, not a built-in)
    
    def prt(*args, **kwargs):
        force = kwargs.pop('force', False)
        clean = kwargs.pop('clean', False)
        deeper = kwargs.pop('deeper', False)
        if is_master or force:
            if not clean:
                f_back = sys._getframe().f_back
                if deeper and f_back.f_back is not None:
                    f_back = f_back.f_back
                file_desc = f'{f_back.f_code.co_filename:24s}'[-24:]
                builtin_print(f'{time_str()} ({file_desc}, line{f_back.f_lineno:-4d})=>', *args, **kwargs)
            else:
                builtin_print(*args, **kwargs)
    
    __builtin__.print = prt


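After the patch above, the replaced print accepts three extra keyword flags. A small sketch (not from this commit) of the resulting behavior:

init_distributed_mode('local_output')           # patches builtins.print on every rank
print('shown on the local master rank only')    # default: filtered + timestamp/caller prefix
print('shown on every rank', force=True)        # bypass the master-only filter
print('no prefix', clean=True)                  # skip the timestamp/caller prefix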
class SyncPrint(object):
    """Tee stdout/stderr to both the terminal and a log file under local_output_dir."""
    def __init__(self, local_output_dir, sync_stdout=True):
        self.sync_stdout = sync_stdout
        self.terminal_stream = sys.stdout if sync_stdout else sys.stderr
        fname = os.path.join(local_output_dir, 'stdout.txt' if sync_stdout else 'stderr.txt')
        existing = os.path.exists(fname)
        self.file_stream = open(fname, 'a')
        if existing:
            self.file_stream.write('\n'*7 + '='*55 + f' RESTART {time_str()} ' + '='*55 + '\n')
        self.file_stream.flush()
        self.enabled = True
    
    def write(self, message):
        self.terminal_stream.write(message)
        self.file_stream.write(message)
    
    def flush(self):
        self.terminal_stream.flush()
        self.file_stream.flush()
    
    def close(self):
        if not self.enabled:
            return
        self.enabled = False
        self.file_stream.flush()
        self.file_stream.close()
        if self.sync_stdout:
            sys.stdout = self.terminal_stream
            sys.stdout.flush()
        else:
            sys.stderr = self.terminal_stream
            sys.stderr.flush()
    
    def __del__(self):
        self.close()


class DistLogger(object):
    """Proxy that forwards attribute access to the wrapped logger only when verbose."""
    def __init__(self, lg, verbose):
        self._lg, self._verbose = lg, verbose
    
    @staticmethod
    def do_nothing(*args, **kwargs):
        pass
    
    def __getattr__(self, attr: str):
        return getattr(self._lg, attr) if self._verbose else DistLogger.do_nothing


class TensorboardLogger(object):
    def __init__(self, log_dir, filename_suffix):
        try: import tensorflow_io as tfio
        except: pass
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=log_dir, filename_suffix=filename_suffix)
        self.step = 0
    
    def set_step(self, step=None):
        if step is not None:
            self.step = step
        else:
            self.step += 1
    
    def update(self, head='scalar', step=None, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if step is None:    # iter-wise: write only at step 0 and every 500th step
                it = self.step
                if it == 0 or (it + 1) % 500 == 0:
                    if hasattr(v, 'item'): v = v.item()
                    self.writer.add_scalar(f'{head}/{k}', v, it)
            else:               # epoch-wise: always write
                if hasattr(v, 'item'): v = v.item()
                self.writer.add_scalar(f'{head}/{k}', v, step)
    
    def log_tensor_as_distri(self, tag, tensor1d, step=None):
        if step is None:    # iter-wise
            step = self.step
            loggable = step == 0 or (step + 1) % 500 == 0
        else:               # epoch-wise
            loggable = True
        if loggable:
            try:
                self.writer.add_histogram(tag=tag, values=tensor1d, global_step=step)
            except Exception as e:
                print(f'[log_tensor_as_distri writer.add_histogram failed]: {e}')
    
    def log_image(self, tag, img_chw, step=None):
        if step is None:    # iter-wise
            step = self.step
            loggable = step == 0 or (step + 1) % 500 == 0
        else:               # epoch-wise
            loggable = True
        if loggable:
            self.writer.add_image(tag, img_chw, step, dataformats='CHW')
    
    def flush(self):
        self.writer.flush()
    
    def close(self):
        self.writer.close()


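A usage sketch for the wrapper above (not from this commit; the log directory and metric names are arbitrary). Because the iter-wise path only writes every 500th step, per-iteration calls are cheap:

tb_lg = TensorboardLogger(log_dir='./tb-logs', filename_suffix='_demo')
tb_lg = DistLogger(tb_lg, verbose=dist.is_master())  # no-op on non-master ranks
for it in range(1000):
    tb_lg.set_step(it)
    tb_lg.update(head='train', loss=0.5, lr=1e-4)    # written at it = 0, 499, 999
tb_lg.update(head='val', step=0, acc=0.9)            # epoch-wise: written immediately
tb_lg.flush(); tb_lg.close()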
class SmoothedValue(object):
    """Track a series of values and provide access to smoothed values over a
    window or the global series average.
    """
    
    def __init__(self, window_size=30, fmt=None):
        if fmt is None:
            fmt = "{median:.4f} ({global_avg:.4f})"
        self.deque = deque(maxlen=window_size)
        self.total = 0.0
        self.count = 0
        self.fmt = fmt
    
    def update(self, value, n=1):
        self.deque.append(value)
        self.count += n
        self.total += value * n
    
    def synchronize_between_processes(self):
        """
        Warning: does not synchronize the deque!
        """
        t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
        tdist.barrier()
        tdist.all_reduce(t)
        t = t.tolist()
        self.count = int(t[0])
        self.total = t[1]
    
    @property
    def median(self):
        return np.median(self.deque) if len(self.deque) else 0
    
    @property
    def avg(self):
        return sum(self.deque) / (len(self.deque) or 1)
    
    @property
    def global_avg(self):
        return self.total / (self.count or 1)
    
    @property
    def max(self):
        return max(self.deque)
    
    @property
    def value(self):
        return self.deque[-1] if len(self.deque) else 0
    
    def time_preds(self, counts) -> Tuple[float, str, str]:
        remain_secs = counts * self.median
        return remain_secs, str(datetime.timedelta(seconds=round(remain_secs))), time.strftime("%Y-%m-%d %H:%M", time.localtime(time.time() + remain_secs))
    
    def __str__(self):
        return self.fmt.format(
            median=self.median,
            avg=self.avg,
            global_avg=self.global_avg,
            max=self.max,
            value=self.value)


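A quick sketch of SmoothedValue on its own (not from this commit):

sv = SmoothedValue(window_size=30, fmt='{median:.4f} ({global_avg:.4f})')
for v in (0.9, 0.8, 0.7):
    sv.update(v)
print(str(sv))                             # '0.8000 (0.8000)': window median, global average
secs, dt, eta = sv.time_preds(counts=100)  # ETA assuming each remaining step costs the median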
class MetricLogger(object):
    def __init__(self, delimiter='  '):
        self.meters = defaultdict(SmoothedValue)
        self.delimiter = delimiter
        self.iter_end_t = time.time()
        self.log_iters = []
    
    def update(self, **kwargs):
        for k, v in kwargs.items():
            if v is None:
                continue
            if hasattr(v, 'item'): v = v.item()
            assert isinstance(v, (float, int))
            self.meters[k].update(v)
    
    def __getattr__(self, attr):
        if attr in self.meters:
            return self.meters[attr]
        if attr in self.__dict__:
            return self.__dict__[attr]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, attr))
    
    def __str__(self):
        loss_str = []
        for name, meter in self.meters.items():
            if len(meter.deque):
                loss_str.append(
                    "{}: {}".format(name, str(meter))
                )
        return self.delimiter.join(loss_str)
    
    def synchronize_between_processes(self):
        for meter in self.meters.values():
            meter.synchronize_between_processes()
    
    def add_meter(self, name, meter):
        self.meters[name] = meter
    
    def log_every(self, start_it, max_iters, itrt, print_freq, header=None):
        # note: print_freq is the *number* of evenly spaced log points, not a period
        self.log_iters = set(np.linspace(0, max_iters-1, print_freq, dtype=int).tolist())
        self.log_iters.add(start_it)
        if not header:
            header = ''
        start_time = time.time()
        self.iter_end_t = time.time()
        self.iter_time = SmoothedValue(fmt='{avg:.4f}')
        self.data_time = SmoothedValue(fmt='{avg:.4f}')
        space_fmt = ':' + str(len(str(max_iters))) + 'd'
        log_msg = [
            header,
            '[{0' + space_fmt + '}/{1}]',
            'eta: {eta}',
            '{meters}',
            'time: {time}',
            'data: {data}'
        ]
        log_msg = self.delimiter.join(log_msg)
        
        if isinstance(itrt, Iterator) and not hasattr(itrt, 'preload') and not hasattr(itrt, 'set_epoch'):
            for i in range(start_it, max_iters):
                obj = next(itrt)
                self.data_time.update(time.time() - self.iter_end_t)
                yield i, obj
                self.iter_time.update(time.time() - self.iter_end_t)
                if i in self.log_iters:
                    eta_seconds = self.iter_time.global_avg * (max_iters - i)
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print(log_msg.format(
                        i, max_iters, eta=eta_string,
                        meters=str(self),
                        time=str(self.iter_time), data=str(self.data_time)), flush=True)
                self.iter_end_t = time.time()
        else:
            if isinstance(itrt, int): itrt = range(itrt)
            for i, obj in enumerate(itrt):
                self.data_time.update(time.time() - self.iter_end_t)
                yield i, obj
                self.iter_time.update(time.time() - self.iter_end_t)
                if i in self.log_iters:
                    eta_seconds = self.iter_time.global_avg * (max_iters - i)
                    eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
                    print(log_msg.format(
                        i, max_iters, eta=eta_string,
                        meters=str(self),
                        time=str(self.iter_time), data=str(self.data_time)), flush=True)
                self.iter_end_t = time.time()
        
        total_time = time.time() - start_time
        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
        print('{} Total time: {} ({:.3f} s / it)'.format(
            header, total_time_str, total_time / max_iters), flush=True)


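A minimal sketch of driving a loop with MetricLogger.log_every (not from this commit; the range(...) iterable and fake loss stand in for a real dataloader and training step):

me = MetricLogger(delimiter='  ')
max_it = 100
for it, batch in me.log_every(0, max_it, range(max_it), print_freq=10, header='[ep1]'):
    loss = 1.0 / (it + 1)   # pretend training step
    me.update(loss=loss)
# me.synchronize_between_processes()  # all-reduces counts/totals; needs torch.distributed + CUDA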
def glob_with_latest_modified_first(pattern, recursive=False):
    return sorted(glob.glob(pattern, recursive=recursive), key=os.path.getmtime, reverse=True)


def auto_resume(args: arg_util.Args, pattern='ckpt*.pth') -> Tuple[List[str], int, int, dict, dict]:
    info = []
    file = os.path.join(args.local_out_dir_path, pattern)
    all_ckpt = glob_with_latest_modified_first(file)
    if len(all_ckpt) == 0:
        info.append(f'[auto_resume] no ckpt found @ {file}')
        info.append(f'[auto_resume quit]')
        return info, 0, 0, {}, {}
    else:
        info.append(f'[auto_resume] load ckpt from @ {all_ckpt[0]} ...')
        ckpt = torch.load(all_ckpt[0], map_location='cpu')
        ep, it = ckpt['epoch'], ckpt['iter']
        info.append(f'[auto_resume success] resume from ep{ep}, it{it}')
        return info, ep, it, ckpt['trainer'], ckpt['args']


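A sketch of the intended call site (not from this commit): args is the arg_util.Args instance, and the trainer-state restore is schematic, since the trainer API is defined elsewhere in the repo.

info, start_ep, start_it, trainer_state, args_state = auto_resume(args, pattern='ckpt*.pth')
for msg in info:
    print(msg)
if trainer_state:
    trainer.load_state_dict(trainer_state)  # hypothetical restore call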
def create_npz_from_sample_folder(sample_folder: str):
    """
    Builds a single .npz file from a folder of .png samples. Refer to DiT.
    """
    import os, glob
    import numpy as np
    from tqdm import tqdm
    from PIL import Image
    
    samples = []
    pngs = glob.glob(os.path.join(sample_folder, '*.png')) + glob.glob(os.path.join(sample_folder, '*.PNG'))
    assert len(pngs) == 50_000, f'{len(pngs)} png files found in {sample_folder}, but expected 50,000'
    for png in tqdm(pngs, desc='Building .npz file from samples (png only)'):
        with Image.open(png) as sample_pil:
            sample_np = np.asarray(sample_pil).astype(np.uint8)
            samples.append(sample_np)
    samples = np.stack(samples)
    assert samples.shape == (50_000, samples.shape[1], samples.shape[2], 3)
    npz_path = f'{sample_folder}.npz'
    np.savez(npz_path, arr_0=samples)
    print(f'Saved .npz file to {npz_path} [shape={samples.shape}].')
    return npz_path
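A usage sketch (not from this commit): the folder name is made up, and the function deliberately asserts exactly 50,000 PNGs, matching the DiT/ADM FID evaluation protocol.

npz_path = create_npz_from_sample_folder('samples/var-d16')
# -> writes samples/var-d16.npz containing arr_0 of shape (50000, H, W, 3)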