init
This view is limited to 50 files because it contains too many changes.
- Dockerfile +30 -0
- app.py +267 -0
- models/__init__.py +63 -0
- models/__pycache__/__init__.cpython-37.pyc +0 -0
- models/__pycache__/__init__.cpython-39.pyc +0 -0
- models/__pycache__/bev3d_generator.cpython-37.pyc +0 -0
- models/__pycache__/bev3d_generator.cpython-39.pyc +0 -0
- models/__pycache__/eg3d_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/eg3d_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/eg3d_generator.cpython-37.pyc +0 -0
- models/__pycache__/eg3d_generator.cpython-39.pyc +0 -0
- models/__pycache__/eg3d_generator_fv.cpython-37.pyc +0 -0
- models/__pycache__/eg3d_generator_fv.cpython-39.pyc +0 -0
- models/__pycache__/ghfeat_encoder.cpython-37.pyc +0 -0
- models/__pycache__/ghfeat_encoder.cpython-39.pyc +0 -0
- models/__pycache__/inception_model.cpython-37.pyc +0 -0
- models/__pycache__/inception_model.cpython-39.pyc +0 -0
- models/__pycache__/perceptual_model.cpython-37.pyc +0 -0
- models/__pycache__/perceptual_model.cpython-39.pyc +0 -0
- models/__pycache__/pggan_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/pggan_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/pggan_generator.cpython-37.pyc +0 -0
- models/__pycache__/pggan_generator.cpython-39.pyc +0 -0
- models/__pycache__/pigan_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/pigan_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/pigan_generator.cpython-37.pyc +0 -0
- models/__pycache__/pigan_generator.cpython-39.pyc +0 -0
- models/__pycache__/sgbev3d_generator.cpython-37.pyc +0 -0
- models/__pycache__/sgbev3d_generator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan2_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan2_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan2_generator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan2_generator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan3_generator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan3_generator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/stylegan_generator.cpython-37.pyc +0 -0
- models/__pycache__/stylegan_generator.cpython-39.pyc +0 -0
- models/__pycache__/volumegan_discriminator.cpython-37.pyc +0 -0
- models/__pycache__/volumegan_discriminator.cpython-39.pyc +0 -0
- models/__pycache__/volumegan_generator.cpython-37.pyc +0 -0
- models/__pycache__/volumegan_generator.cpython-39.pyc +0 -0
- models/bev3d_generator.py +301 -0
- models/eg3d_discriminator.py +243 -0
- models/eg3d_generator.py +315 -0
- models/eg3d_generator_fv.py +320 -0
- models/ghfeat_encoder.py +563 -0
- models/inception_model.py +562 -0
- models/perceptual_model.py +519 -0
Dockerfile
ADDED
@@ -0,0 +1,30 @@
FROM nvidia/cuda:11.1.0-devel-ubuntu22.04

ENV CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:/home/${USER_NAME}/.local/bin:${PATH}
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
ENV LIBRARY_PATH=${CUDA_HOME}/lib64/stubs:${LIBRARY_PATH}

# apt install by root user
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    curl \
    git \
    python-is-python3 \
    python3.7-dev \
    python3-pip \
    wget \
    && rm -rf /var/lib/apt/lists/*

RUN pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html

WORKDIR /code

COPY ./requirements.txt /code/requirements.txt

RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

COPY . .

CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "7860"]
app.py
ADDED
@@ -0,0 +1,267 @@
import gradio as gr
from models import build_model
from PIL import Image
import numpy as np
import torchvision
import ninja
import torch
from tqdm import trange
import imageio

# Load the pre-trained BerfScene (CLEVR) generator from the checkpoint.
checkpoint = '/mnt/petrelfs/zhangqihang/data/berfscene_clevr.pth'
state = torch.load(checkpoint, map_location='cpu')
G = build_model(**state['model_kwargs_init']['generator_smooth'])
o0, o1 = G.load_state_dict(state['models']['generator_smooth'], strict=False)
G.eval().cuda()
G.backbone.synthesis.input.x_offset = 0
G.backbone.synthesis.input.y_offset = 0
G_kwargs = dict(noise_mode='const',
                fused_modulate=False,
                impl='cuda',
                fp16_res=None)


def trans(x, y, z, length):
    # Map object coordinates in CLEVR world space to pixel coordinates on the
    # BEV canvas.
    w = h = length
    x = 0.5 * w - 128 + 256 - (x / 9 + .5) * 256
    y = 0.5 * h - 128 + (y / 9 + .5) * 256
    z = z / 9 * 256
    return x, y, z


def get_bev_from_objs(objs, length=256, scale=6):
    # Rasterize the object list into a 14-channel bird's-eye-view feature map:
    # 1 occupancy channel + one-hot color (8) + shape (3) + material (2).
    h, w = length, length * scale
    nc = 14
    canvas = np.zeros([h, w, nc])
    xx = np.ones([h, w]).cumsum(0)
    yy = np.ones([h, w]).cumsum(1)

    for x, y, z, shape, color, material, rot in objs:
        y, x, z = trans(x, y, z, length)

        feat = [0] * nc
        feat[0] = 1
        feat[COLOR_NAME_LIST.index(color) + 1] = 1
        feat[SHAPE_NAME_LIST.index(shape) + 1 + len(COLOR_NAME_LIST)] = 1
        feat[MATERIAL_NAME_LIST.index(material) + 1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST)] = 1
        feat = np.array(feat)
        rot_sin = np.sin(rot / 180 * np.pi)
        rot_cos = np.cos(rot / 180 * np.pi)

        if shape == 'cube':
            mask = (np.abs(+rot_cos * (xx - x) + rot_sin * (yy - y)) <= z) * \
                   (np.abs(-rot_sin * (xx - x) + rot_cos * (yy - y)) <= z)
        else:
            mask = ((xx - x) ** 2 + (y - yy) ** 2) ** 0.5 <= z
        canvas[mask] = feat
    canvas = np.transpose(canvas, [2, 0, 1]).astype(np.float32)
    rotate_angle = 0
    canvas = torchvision.transforms.functional.rotate(torch.tensor(canvas), rotate_angle).numpy()
    return canvas


# COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'brown', 'blue']
COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue']
SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
MATERIAL_NAME_LIST = ['rubber', 'metal']

# Per-letter layouts of cube positions used to spell out 'BeRFsCene'.
# 'B' was drafted three times; only the last assignment takes effect.
xy_lib = dict()
xy_lib['B'] = [[-2, -1], [-1, -1], [-2, 0], [-2, 1], [-1, .5], [0, 1],
               [0, 0], [0, -1], [0, 2], [-1, 2], [-2, 2]]
xy_lib['B'] = [[-2.5, 1.25], [-2, 2], [-2, 0.5], [-2, -0.75], [-1, -1], [-1, 2],
               [-1, 0], [-1, 2], [0, 1], [0, 0], [0, -1], [0, 2]]
xy_lib['B'] = [[-2.5, 1.25], [-2, 2], [-2, 0.5], [-2, -1], [-1, -1.25], [-1, 2],
               [-1, 0], [-1, 2], [0, 1], [0, 0], [0, -1.25], [0, 2]]
xy_lib['R'] = [[0, -1], [0, 0], [0, 1], [0, 2], [-1, -1],
               [-2, -1], [-2, 0], [-2.25, 2], [-1, 1]]
xy_lib['C'] = [[0, -1], [0, 0], [0, 1], [0, 2], [-1, -1], [-1, 2],
               [-2, -1], [-2, 2]]
xy_lib['s'] = [[0, -1], [0, 0], [0, 2], [-1, -1], [-1, 2],
               [-2, -1], [-2, 1], [-2, 2], [-1, .5]]
xy_lib['F'] = [[0, -1], [0, 0], [0, 1], [0, 2], [-1, -1],
               [-2, -1], [-2, .5], [-1, .5]]
xy_lib['c'] = [[0.8, 1], [0, 0.1], [0, 1.9]]
xy_lib['e'] = [[0, -1], [0, 0], [0, 1], [0, 2], [-1, -1], [-1, 2],
               [-2, -1], [-2, .5], [-2, 2], [-1, .5]]
xy_lib['n'] = [[0, 1], [0, -1], [0, 0.1], [0, 1.9], [-1, 0], [-2, 1],
               [-3, -1], [-3, 1], [-3, 0.1], [-3, 1.9]]

# Horizontal spacing between consecutive letters.
offset_x = dict(B=4, R=4, C=4, F=4, c=3, s=4, e=4, n=4.8)
s = 'BeRFsCene'
objs = []
offset = 2
for idx, c in enumerate(s):
    xy = xy_lib[c]
    color = np.random.choice(COLOR_NAME_LIST)
    for i in range(len(xy)):
        x, y = xy[i]
        y *= 1.5
        y -= 0.5
        x -= offset
        z = 0.35
        shape = 'cube'
        material = 'rubber'
        rot = 0
        objs.append([x, y, z, shape, color, material, rot])
    offset += offset_x[c]

# Preview of the occupancy channel of the resulting BEV map (result unused).
Image.fromarray((255 * .8 - get_bev_from_objs(objs)[0] * .8 * 255).astype(np.uint8))

batch_size = 1
code = torch.randn(1, G.z_dim).cuda()
to_pil = torchvision.transforms.ToPILImage()
large_bevs = torch.tensor(get_bev_from_objs(objs)).cuda()[None]
bevs = large_bevs[..., 0: 0 + 256]
# Camera conditioning: flattened 4x4 camera-to-world extrinsics followed by
# 3x3 intrinsics (25 values).
RT = torch.tensor([[-1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.5000, -0.8660,
                    10.3923, 0.0000, -0.8660, -0.5000, 6.0000, 0.0000, 0.0000,
                    0.0000, 1.0000, 262.5000, 0.0000, 32.0000, 0.0000, 262.5000,
                    32.0000, 0.0000, 0.0000, 1.0000]], device='cuda')

print('prepare finish', flush=True)


def inference(name):
    print('inference', name, flush=True)
    gen = G(code, RT, bevs)
    rgb = gen['gen_output']['image'][0] * .5 + .5
    print('inference', name, flush=True)
    return np.array(to_pil(rgb))

# to_pil(rgb).save('tmp.png')
# save_path = '/mnt/petrelfs/zhangqihang/code/3d-scene-gen/tmp.png'
# return [save_path]

with gr.Blocks() as demo:
    gr.HTML(
        """
        abc
        """)

    with gr.Group():
        with gr.Row():
            with gr.Column():
                with gr.Row():
                    with gr.Column():
                        with gr.Row():
                            num_frames = gr.Dropdown(
                                ["24 - frames", "32 - frames", "40 - frames",
                                 "48 - frames", "56 - frames",
                                 "80 - recommended to run on local GPUs",
                                 "240 - recommended to run on local GPUs",
                                 "600 - recommended to run on local GPUs",
                                 "1200 - recommended to run on local GPUs",
                                 "10000 - recommended to run on local GPUs"],
                                label="Number of Video Frames",
                                info="For >56 frames use local workstation!",
                                value="24 - frames")

                with gr.Row():
                    with gr.Row():
                        btn = gr.Button("Result")

            gallery = gr.Image(label='img', show_label=True, elem_id="gallery")

    btn.click(fn=inference, inputs=num_frames, outputs=[gallery], postprocess=False)

demo.queue()
demo.launch(server_name='0.0.0.0', server_port=10093, debug=True, show_error=True)
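For reference, the 14-channel layout written by get_bev_from_objs() is one occupancy bit followed by one-hot color, shape, and material indicators. The snippet below is a minimal standalone sketch of that encoding (not part of the committed files; the helper name encode_object is made up for illustration), reusing the label lists from app.py:

import numpy as np

# Label lists copied from app.py; nc = 1 + 8 + 3 + 2 = 14 channels.
COLOR_NAME_LIST = ['cyan', 'green', 'purple', 'red', 'yellow', 'gray', 'purple', 'blue']
SHAPE_NAME_LIST = ['cube', 'sphere', 'cylinder']
MATERIAL_NAME_LIST = ['rubber', 'metal']

def encode_object(color, shape, material, nc=14):
    """Build the per-pixel one-hot feature that get_bev_from_objs() stamps
    into the canvas wherever an object's footprint mask is true."""
    feat = np.zeros(nc, dtype=np.float32)
    feat[0] = 1.0  # occupancy bit
    feat[1 + COLOR_NAME_LIST.index(color)] = 1.0
    feat[1 + len(COLOR_NAME_LIST) + SHAPE_NAME_LIST.index(shape)] = 1.0
    feat[1 + len(COLOR_NAME_LIST) + len(SHAPE_NAME_LIST)
         + MATERIAL_NAME_LIST.index(material)] = 1.0
    return feat

print(encode_object('red', 'cube', 'rubber'))

Note that because list.index() returns the first match, the duplicated 'purple' entry at index 6 of COLOR_NAME_LIST can never be selected.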
models/__init__.py
ADDED
@@ -0,0 +1,63 @@
# python3.7
"""Collects all models."""

from .pggan_generator import PGGANGenerator
from .pggan_discriminator import PGGANDiscriminator
from .stylegan_generator import StyleGANGenerator
from .stylegan_discriminator import StyleGANDiscriminator
from .stylegan2_generator import StyleGAN2Generator
from .stylegan2_discriminator import StyleGAN2Discriminator
from .stylegan3_generator import StyleGAN3Generator
from .ghfeat_encoder import GHFeatEncoder
from .perceptual_model import PerceptualModel
from .inception_model import InceptionModel
from .eg3d_generator import EG3DGenerator
from .eg3d_discriminator import DualDiscriminator
from .pigan_generator import PiGANGenerator
from .pigan_discriminator import PiGANDiscriminator
from .volumegan_generator import VolumeGANGenerator
from .volumegan_discriminator import VolumeGANDiscriminator
from .eg3d_generator_fv import EG3DGeneratorFV
from .bev3d_generator import BEV3DGenerator
from .sgbev3d_generator import SGBEV3DGenerator

__all__ = ['build_model']

_MODELS = {
    'PGGANGenerator': PGGANGenerator,
    'PGGANDiscriminator': PGGANDiscriminator,
    'StyleGANGenerator': StyleGANGenerator,
    'StyleGANDiscriminator': StyleGANDiscriminator,
    'StyleGAN2Generator': StyleGAN2Generator,
    'StyleGAN2Discriminator': StyleGAN2Discriminator,
    'StyleGAN3Generator': StyleGAN3Generator,
    'GHFeatEncoder': GHFeatEncoder,
    'PerceptualModel': PerceptualModel.build_model,
    'InceptionModel': InceptionModel.build_model,
    'EG3DGenerator': EG3DGenerator,
    'EG3DDiscriminator': DualDiscriminator,
    'PiGANGenerator': PiGANGenerator,
    'PiGANDiscriminator': PiGANDiscriminator,
    'VolumeGANGenerator': VolumeGANGenerator,
    'VolumeGANDiscriminator': VolumeGANDiscriminator,
    'EG3DGeneratorFV': EG3DGeneratorFV,
    'BEV3DGenerator': BEV3DGenerator,
    'SGBEV3DGenerator': SGBEV3DGenerator,
}


def build_model(model_type, **kwargs):
    """Builds a model based on its class type.

    Args:
        model_type: Class type to which the model belongs, which is case
            sensitive.
        **kwargs: Additional arguments to build the model.

    Raises:
        ValueError: If the `model_type` is not supported.
    """
    if model_type not in _MODELS:
        raise ValueError(f'Invalid model type: `{model_type}`!\n'
                         f'Types allowed: {list(_MODELS)}.')
    return _MODELS[model_type](**kwargs)
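build_model() is a plain registry lookup: the class-name string selects an entry in _MODELS and the remaining keyword arguments are forwarded to that class's constructor. A small hedged usage sketch (it assumes the repo's models package and its heavy imports are available; in practice the kwargs come from a checkpoint's model_kwargs_init dict, as app.py does):

from models import build_model

# Normal use, as in app.py:
#   G = build_model(**state['model_kwargs_init']['generator_smooth'])
# Unknown names raise a ValueError that lists the allowed class types.
try:
    build_model('NotARealModel')
except ValueError as err:
    print(err)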
models/__pycache__/__init__.cpython-37.pyc  ADDED  Binary file (2.06 kB)
models/__pycache__/__init__.cpython-39.pyc  ADDED  Binary file (2.08 kB)
models/__pycache__/bev3d_generator.cpython-37.pyc  ADDED  Binary file (6.16 kB)
models/__pycache__/bev3d_generator.cpython-39.pyc  ADDED  Binary file (6.07 kB)
models/__pycache__/eg3d_discriminator.cpython-37.pyc  ADDED  Binary file (8.01 kB)
models/__pycache__/eg3d_discriminator.cpython-39.pyc  ADDED  Binary file (7.73 kB)
models/__pycache__/eg3d_generator.cpython-37.pyc  ADDED  Binary file (6.21 kB)
models/__pycache__/eg3d_generator.cpython-39.pyc  ADDED  Binary file (6.3 kB)
models/__pycache__/eg3d_generator_fv.cpython-37.pyc  ADDED  Binary file (6.35 kB)
models/__pycache__/eg3d_generator_fv.cpython-39.pyc  ADDED  Binary file (6.43 kB)
models/__pycache__/ghfeat_encoder.cpython-37.pyc  ADDED  Binary file (14.3 kB)
models/__pycache__/ghfeat_encoder.cpython-39.pyc  ADDED  Binary file (14.1 kB)
models/__pycache__/inception_model.cpython-37.pyc  ADDED  Binary file (16 kB)
models/__pycache__/inception_model.cpython-39.pyc  ADDED  Binary file (15.7 kB)
models/__pycache__/perceptual_model.cpython-37.pyc  ADDED  Binary file (14.3 kB)
models/__pycache__/perceptual_model.cpython-39.pyc  ADDED  Binary file (14 kB)
models/__pycache__/pggan_discriminator.cpython-37.pyc  ADDED  Binary file (12 kB)
models/__pycache__/pggan_discriminator.cpython-39.pyc  ADDED  Binary file (11.9 kB)
models/__pycache__/pggan_generator.cpython-37.pyc  ADDED  Binary file (10.6 kB)
models/__pycache__/pggan_generator.cpython-39.pyc  ADDED  Binary file (10.6 kB)
models/__pycache__/pigan_discriminator.cpython-37.pyc  ADDED  Binary file (8.32 kB)
models/__pycache__/pigan_discriminator.cpython-39.pyc  ADDED  Binary file (8.31 kB)
models/__pycache__/pigan_generator.cpython-37.pyc  ADDED  Binary file (12.7 kB)
models/__pycache__/pigan_generator.cpython-39.pyc  ADDED  Binary file (12.8 kB)
models/__pycache__/sgbev3d_generator.cpython-37.pyc  ADDED  Binary file (7.01 kB)
models/__pycache__/sgbev3d_generator.cpython-39.pyc  ADDED  Binary file (7.04 kB)
models/__pycache__/stylegan2_discriminator.cpython-37.pyc  ADDED  Binary file (17.7 kB)
models/__pycache__/stylegan2_discriminator.cpython-39.pyc  ADDED  Binary file (17.7 kB)
models/__pycache__/stylegan2_generator.cpython-37.pyc  ADDED  Binary file (32.9 kB)
models/__pycache__/stylegan2_generator.cpython-39.pyc  ADDED  Binary file (32.9 kB)
models/__pycache__/stylegan3_generator.cpython-37.pyc  ADDED  Binary file (35.8 kB)
models/__pycache__/stylegan3_generator.cpython-39.pyc  ADDED  Binary file (35.7 kB)
models/__pycache__/stylegan_discriminator.cpython-37.pyc  ADDED  Binary file (15.9 kB)
models/__pycache__/stylegan_discriminator.cpython-39.pyc  ADDED  Binary file (15.9 kB)
models/__pycache__/stylegan_generator.cpython-37.pyc  ADDED  Binary file (24.9 kB)
models/__pycache__/stylegan_generator.cpython-39.pyc  ADDED  Binary file (24.9 kB)
models/__pycache__/volumegan_discriminator.cpython-37.pyc  ADDED  Binary file (17.8 kB)
models/__pycache__/volumegan_discriminator.cpython-39.pyc  ADDED  Binary file (17.8 kB)
models/__pycache__/volumegan_generator.cpython-37.pyc  ADDED  Binary file (18.2 kB)
models/__pycache__/volumegan_generator.cpython-39.pyc  ADDED  Binary file (18.2 kB)
models/bev3d_generator.py
ADDED
@@ -0,0 +1,301 @@
# python3.8
"""Contains the implementation of generator described in BEV3D."""

import torch
import torch.nn as nn
from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
from models.utils.eg3d_superres import SuperresolutionHybrid2X
from models.utils.eg3d_superres import SuperresolutionHybrid4X
from models.utils.eg3d_superres import SuperresolutionHybrid4X_conststyle
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
from models.rendering.renderer import Renderer
from models.rendering.feature_extractor import FeatureExtractor

from models.utils.spade import SPADEGenerator


class BEV3DGenerator(nn.Module):

    def __init__(
        self,
        z_dim,
        semantic_nc,
        ngf,
        bev_grid_size,
        aspect_ratio,
        num_upsampling_layers,
        not_use_vae,
        norm_G,
        img_resolution,
        interpolate_sr,
        segmask=False,
        dim_seq='16,8,4,2,1',
        xyz_pe=False,
        hidden_dim=64,
        additional_layer_num=0,
        sr_num_fp16_res=0,    # Number of fp16 layers of SR Network.
        rendering_kwargs={},  # Arguments for rendering.
        sr_kwargs={},         # Arguments for SuperResolution Network.
    ):
        super().__init__()

        self.z_dim = z_dim
        self.interpolate_sr = interpolate_sr
        self.segmask = segmask

        # Set up the overall renderer.
        self.renderer = Renderer()

        # Set up the feature extractor.
        self.feature_extractor = FeatureExtractor(ref_mode='bev_plane_clevr', xyz_pe=xyz_pe)

        # Set up the reference representation generator.
        self.backbone = SPADEGenerator(z_dim=z_dim, semantic_nc=semantic_nc, ngf=ngf, dim_seq=dim_seq,
                                       bev_grid_size=bev_grid_size, aspect_ratio=aspect_ratio,
                                       num_upsampling_layers=num_upsampling_layers,
                                       not_use_vae=not_use_vae, norm_G=norm_G)
        print('backbone SPADEGenerator set up!')

        # Set up the post module in the feature extractor.
        self.post_module = None

        # Set up the post neural renderer.
        self.post_neural_renderer = None
        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],)
        sr_kwargs_total.update(**sr_kwargs)
        if img_resolution == 128:
            self.post_neural_renderer = SuperresolutionHybrid2X(
                **sr_kwargs_total)
        elif img_resolution == 256:
            self.post_neural_renderer = SuperresolutionHybrid4X_conststyle(
                **sr_kwargs_total)
        elif img_resolution == 512:
            self.post_neural_renderer = SuperresolutionHybrid8XDC(
                **sr_kwargs_total)
        else:
            raise TypeError(f'Unsupported image resolution: {img_resolution}!')

        # Set up the fully-connected layer head.
        self.fc_head = OSGDecoder(
            128 if xyz_pe else 64, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            },
            hidden_dim=hidden_dim,
            additional_layer_num=additional_layer_num
        )

        # Set up some rendering related arguments.
        self.neural_rendering_resolution = rendering_kwargs.get(
            'resolution', 64)
        self.rendering_kwargs = rendering_kwargs

    def synthesis(self,
                  z,
                  c,
                  seg,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        if self.rendering_kwargs.get('random_pose', False):
            cam2world_matrix = None

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        xy_planes = self.backbone(z=z, input=seg)
        if self.segmask:
            xy_planes = xy_planes * seg[:, 0, ...][:, None, ...]

        wp = z  # in our case, we do not use wp.

        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=None,
            ref_representation=xy_planes,
            post_module=self.post_module,
            fc_head=self.fc_head)

        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Reshape to keep consistent with 'raw' neural-rendered image.
        N = wp.shape[0]
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Run the post neural renderer to get final image.
        # Here, the post neural renderer is a super-resolution network.
        rgb_image = feature_image[:, :3]
        if self.interpolate_sr:
            sr_image = torch.nn.functional.interpolate(rgb_image, size=(256, 256), mode='bilinear', align_corners=False)
        else:
            sr_image = self.post_neural_renderer(
                rgb_image,
                feature_image,
                # wp,
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                **{
                    k: synthesis_kwargs[k]
                    for k in synthesis_kwargs.keys() if k != 'noise_mode'
                })

        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image
        }

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               seg,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        # Compute RGB features, density for arbitrary 3D coordinates.
        # Mostly used for extracting shapes.
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        xy_planes = self.backbone(z=z, input=seg)
        wp = z
        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=xy_planes,
            post_module=self.post_module,
            ray_dirs=directions,
            cam_matrix=cam2world_matrix)

        return result

    def sample_mixed(self,
                     coordinates,
                     directions,
                     z, c, seg,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        # Same as function `self.sample()`, but expects latent vectors 'wp'
        # instead of Gaussian noise 'z'.
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        xy_planes = self.backbone(z=z, input=seg)
        wp = z
        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=xy_planes,
            post_module=self.post_module,
            ray_dirs=directions,
            cam_matrix=cam2world_matrix)

        return result

    def forward(self,
                z,
                c,
                seg,
                c_swapped=None,  # `c_swapped` is swapped pose conditioning.
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):

        # Render a batch of generated images.
        c_wp = c.clone()
        if c_swapped is not None:
            c_wp = c_swapped.clone()

        if not sample_mixed:
            gen_output = self.synthesis(
                z,
                c,
                seg,
                update_emas=update_emas,
                neural_rendering_resolution=neural_rendering_resolution,
                **synthesis_kwargs)

            return {
                'wp': z,
                'gen_output': gen_output,
            }

        else:
            # Only for density regularization in training process.
            assert coordinates is not None
            sample_sigma = self.sample_mixed(coordinates,
                                             torch.randn_like(coordinates),
                                             z, c, seg,
                                             update_emas=False)['sigma']

            return {
                'wp': z,
                'sample_sigma': sample_sigma
            }


class OSGDecoder(nn.Module):
    """Defines fully-connected layer head in EG3D."""
    def __init__(self, n_features, options, hidden_dim=64, additional_layer_num=0):
        super().__init__()
        self.hidden_dim = hidden_dim

        lst = []
        lst.append(FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
        lst.append(nn.Softplus())
        for i in range(additional_layer_num):
            lst.append(FullyConnectedLayer(self.hidden_dim, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']))
            lst.append(nn.Softplus())
        lst.append(FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul']))
        self.net = nn.Sequential(*lst)

    def forward(self, point_features, wp=None, dirs=None):
        # Aggregate features.
        # point_features.shape: [N, R, K, C].
        N, R, K, C = point_features.shape
        x = point_features.reshape(-1, point_features.shape[-1])
        x = self.net(x)
        x = x.view(N, -1, x.shape[-1])

        # Uses sigmoid clamping from MipNeRF.
        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
        sigma = x[..., 0:1]

        return {'rgb': rgb, 'sigma': sigma}
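In BEV3DGenerator.synthesis() the renderer returns per-ray composites of shape [N, H*W, C], and the permute/reshape turns them back into an image grid before super-resolution. A standalone dummy-tensor sketch of just that step (shapes assumed from the defaults above: neural rendering resolution 64, 32 feature channels):

import torch

N, H, W, C = 2, 64, 64, 32                     # batch, rendering res, feature dim
feature_samples = torch.randn(N, H * W, C)     # one composite feature per ray

# Same transform as in BEV3DGenerator.synthesis().
feature_image = feature_samples.permute(0, 2, 1).reshape(N, C, H, W).contiguous()
rgb_image = feature_image[:, :3]               # first 3 channels act as the raw RGB image

print(feature_image.shape, rgb_image.shape)    # [2, 32, 64, 64] and [2, 3, 64, 64]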
models/eg3d_discriminator.py
ADDED
@@ -0,0 +1,243 @@
# python 3.7
"""Contains the implementation of discriminator described in EG3D."""


import numpy as np
import torch
from third_party.stylegan2_official_ops import upfirdn2d
from models.utils.official_stylegan2_model_helper import DiscriminatorBlock
from models.utils.official_stylegan2_model_helper import MappingNetwork
from models.utils.official_stylegan2_model_helper import DiscriminatorEpilogue


class SingleDiscriminator(torch.nn.Module):
    def __init__(self,
                 c_dim,                  # Conditioning label (C) dimensionality.
                 img_resolution,         # Input resolution.
                 img_channels,           # Number of input color channels.
                 architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
                 channel_base=32768,     # Overall multiplier for the number of channels.
                 channel_max=512,        # Maximum number of channels in any layer.
                 num_fp16_res=4,         # Use FP16 for the N highest resolutions.
                 conv_clamp=256,         # Clamp the output of convolution layers to +-X, None = disable clamping.
                 cmap_dim=None,          # Dimensionality of mapped conditioning label, None = default.
                 sr_upsample_factor=1,   # Ignored for SingleDiscriminator.
                 block_kwargs={},        # Arguments for DiscriminatorBlock.
                 mapping_kwargs={},      # Arguments for MappingNetwork.
                 epilogue_kwargs={},     # Arguments for DiscriminatorEpilogue.
                 ):
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                                       first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)

    def forward(self, img, c, update_emas=False, **block_kwargs):
        img = img['image']

        _ = update_emas  # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

#----------------------------------------------------------------------------

def filtered_resizing(image_orig_tensor, size, f, filter_mode='antialiased'):
    if filter_mode == 'antialiased':
        ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
    elif filter_mode == 'classic':
        ada_filtered_64 = upfirdn2d.upsample2d(image_orig_tensor, f, up=2)
        ada_filtered_64 = torch.nn.functional.interpolate(ada_filtered_64, size=(size * 2 + 2, size * 2 + 2), mode='bilinear', align_corners=False)
        ada_filtered_64 = upfirdn2d.downsample2d(ada_filtered_64, f, down=2, flip_filter=True, padding=-1)
    elif filter_mode == 'none':
        ada_filtered_64 = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
    elif type(filter_mode) == float:
        assert 0 < filter_mode < 1

        filtered = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
        aliased = torch.nn.functional.interpolate(image_orig_tensor, size=(size, size), mode='bilinear', align_corners=False)
        ada_filtered_64 = (1 - filter_mode) * aliased + (filter_mode) * filtered

    return ada_filtered_64

#----------------------------------------------------------------------------

class DualDiscriminator(torch.nn.Module):
    def __init__(self,
                 c_dim,                  # Conditioning label (C) dimensionality.
                 img_resolution,         # Input resolution.
                 img_channels,           # Number of input color channels.
                 bev_channels=0,
                 architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
                 channel_base=32768,     # Overall multiplier for the number of channels.
                 channel_max=512,        # Maximum number of channels in any layer.
                 num_fp16_res=4,         # Use FP16 for the N highest resolutions.
                 conv_clamp=256,         # Clamp the output of convolution layers to +-X, None = disable clamping.
                 cmap_dim=None,          # Dimensionality of mapped conditioning label, None = default.
                 disc_c_noise=0,         # Corrupt camera parameters with X std dev of noise before disc. pose conditioning.
                 block_kwargs={},        # Arguments for DiscriminatorBlock.
                 mapping_kwargs={},      # Arguments for MappingNetwork.
                 epilogue_kwargs={},     # Arguments for DiscriminatorEpilogue.
                 ):
        super().__init__()
        img_channels *= 2

        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels + bev_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=self.img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                                       first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))
        self.disc_c_noise = disc_c_noise

    def forward(self, img, c, bev=None, update_emas=False, **block_kwargs):
        image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter)
        img = torch.cat([img['image'], image_raw], 1)
        if bev is not None:
            img = torch.cat([img, bev], 1)

        _ = update_emas  # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            if self.disc_c_noise > 0:
                c += torch.randn_like(c) * c.std(0) * self.disc_c_noise
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

#----------------------------------------------------------------------------

class DummyDualDiscriminator(torch.nn.Module):
    def __init__(self,
                 c_dim,                  # Conditioning label (C) dimensionality.
                 img_resolution,         # Input resolution.
                 img_channels,           # Number of input color channels.
                 architecture='resnet',  # Architecture: 'orig', 'skip', 'resnet'.
                 channel_base=32768,     # Overall multiplier for the number of channels.
                 channel_max=512,        # Maximum number of channels in any layer.
                 num_fp16_res=4,         # Use FP16 for the N highest resolutions.
                 conv_clamp=256,         # Clamp the output of convolution layers to +-X, None = disable clamping.
                 cmap_dim=None,          # Dimensionality of mapped conditioning label, None = default.
                 block_kwargs={},        # Arguments for DiscriminatorBlock.
                 mapping_kwargs={},      # Arguments for MappingNetwork.
                 epilogue_kwargs={},     # Arguments for DiscriminatorEpilogue.
                 ):
        super().__init__()
        img_channels *= 2

        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.block_resolutions = [2 ** i for i in range(self.img_resolution_log2, 2, -1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions + [4]}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        if cmap_dim is None:
            cmap_dim = channels_dict[4]
        if c_dim == 0:
            cmap_dim = 0

        common_kwargs = dict(img_channels=img_channels, architecture=architecture, conv_clamp=conv_clamp)
        cur_layer_idx = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res] if res < img_resolution else 0
            tmp_channels = channels_dict[res]
            out_channels = channels_dict[res // 2]
            use_fp16 = (res >= fp16_resolution)
            block = DiscriminatorBlock(in_channels, tmp_channels, out_channels, resolution=res,
                                       first_layer_idx=cur_layer_idx, use_fp16=use_fp16, **block_kwargs, **common_kwargs)
            setattr(self, f'b{res}', block)
            cur_layer_idx += block.num_layers
        if c_dim > 0:
            self.mapping = MappingNetwork(z_dim=0, c_dim=c_dim, w_dim=cmap_dim, num_ws=None, w_avg_beta=None, **mapping_kwargs)
        self.b4 = DiscriminatorEpilogue(channels_dict[4], cmap_dim=cmap_dim, resolution=4, **epilogue_kwargs, **common_kwargs)
        self.register_buffer('resample_filter', upfirdn2d.setup_filter([1, 3, 3, 1]))

        self.raw_fade = 1

    def forward(self, img, c, update_emas=False, **block_kwargs):
        self.raw_fade = max(0, self.raw_fade - 1 / (500000 / 32))

        image_raw = filtered_resizing(img['image_raw'], size=img['image'].shape[-1], f=self.resample_filter) * self.raw_fade
        img = torch.cat([img['image'], image_raw], 1)

        _ = update_emas  # unused
        x = None
        for res in self.block_resolutions:
            block = getattr(self, f'b{res}')
            x, img = block(x, img, **block_kwargs)

        cmap = None
        if self.c_dim > 0:
            cmap = self.mapping(None, c)
        x = self.b4(x, img, cmap)
        return x

    def extra_repr(self):
        return f'c_dim={self.c_dim:d}, img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d}'

#----------------------------------------------------------------------------
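DualDiscriminator doubles img_channels because it scores the super-resolved image and a resized copy of the raw neural rendering side by side (plus optional BEV channels). A minimal sketch of that input assembly with dummy tensors, using plain bilinear resizing, which corresponds to the 'antialiased' branch of filtered_resizing; shapes are illustrative only:

import torch
import torch.nn.functional as F

img = {
    'image': torch.randn(4, 3, 256, 256),    # super-resolved output
    'image_raw': torch.randn(4, 3, 64, 64),  # raw neural rendering
}

# Resize the raw rendering to the final resolution, then stack along channels:
# the discriminator therefore sees 2 * img_channels = 6 input channels.
size = img['image'].shape[-1]
image_raw = F.interpolate(img['image_raw'], size=(size, size),
                          mode='bilinear', align_corners=False)
disc_in = torch.cat([img['image'], image_raw], dim=1)
print(disc_in.shape)  # torch.Size([4, 6, 256, 256])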
models/eg3d_generator.py
ADDED
@@ -0,0 +1,315 @@
# python3.8
"""Contains the implementation of generator described in EG3D."""

import torch
import torch.nn as nn
from models.utils.official_stylegan2_model_helper import Generator as StyleGAN2Backbone
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
from models.utils.eg3d_superres import SuperresolutionHybrid2X
from models.utils.eg3d_superres import SuperresolutionHybrid4X
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
from models.rendering.renderer import Renderer
from models.rendering.feature_extractor import FeatureExtractor


class EG3DGenerator(nn.Module):

    def __init__(
        self,
        z_dim,                # Input latent (Z) dimensionality.
        c_dim,                # Conditioning label (C) dimensionality.
        w_dim,                # Intermediate latent (W) dimensionality.
        img_resolution,       # Output resolution.
        img_channels,         # Number of output color channels.
        sr_num_fp16_res=0,    # Number of fp16 layers of SR Network.
        mapping_kwargs={},    # Arguments for MappingNetwork.
        rendering_kwargs={},  # Arguments for rendering.
        sr_kwargs={},         # Arguments for SuperResolution Network.
        **synthesis_kwargs,   # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels

        # Set up the overall renderer.
        self.renderer = Renderer()

        # Set up the feature extractor.
        self.feature_extractor = FeatureExtractor(ref_mode='tri_plane')

        # Set up the reference representation generator.
        self.backbone = StyleGAN2Backbone(z_dim,
                                          c_dim,
                                          w_dim,
                                          img_resolution=256,
                                          img_channels=32 * 3,
                                          mapping_kwargs=mapping_kwargs,
                                          **synthesis_kwargs)

        # Set up the post module in the feature extractor.
        self.post_module = None

        # Set up the post neural renderer.
        self.post_neural_renderer = None
        sr_kwargs_total = dict(
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],)
        sr_kwargs_total.update(**sr_kwargs)
        if img_resolution == 128:
            self.post_neural_renderer = SuperresolutionHybrid2X(
                **sr_kwargs_total)
        elif img_resolution == 256:
            self.post_neural_renderer = SuperresolutionHybrid4X(
                **sr_kwargs_total)
        elif img_resolution == 512:
            self.post_neural_renderer = SuperresolutionHybrid8XDC(
                **sr_kwargs_total)
        else:
            raise TypeError(f'Unsupported image resolution: {img_resolution}!')

        # Set up the fully-connected layer head.
        self.fc_head = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })

        # Set up some rendering related arguments.
        self.neural_rendering_resolution = rendering_kwargs.get(
            'resolution', 64)
        self.rendering_kwargs = rendering_kwargs

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        return self.backbone.mapping(z,
                                     c *
                                     self.rendering_kwargs.get('c_scale', 0),
                                     truncation_psi=truncation_psi,
                                     truncation_cutoff=truncation_cutoff,
                                     update_emas=update_emas)

    def synthesis(self,
                  wp,
                  c,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  **synthesis_kwargs):
        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        if self.rendering_kwargs.get('random_pose', False):
            cam2world_matrix = None

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        tri_planes = self.backbone.synthesis(wp,
                                             update_emas=update_emas,
                                             **synthesis_kwargs)
        tri_planes = tri_planes.view(len(tri_planes), 3, -1,
                                     tri_planes.shape[-2],
                                     tri_planes.shape[-1])

        rendering_result = self.renderer(
            wp=wp,
            feature_extractor=self.feature_extractor,
            rendering_options=self.rendering_kwargs,
            cam2world_matrix=cam2world_matrix,
            position_encoder=None,
            ref_representation=tri_planes,
            post_module=self.post_module,
            fc_head=self.fc_head)

        feature_samples = rendering_result['composite_rgb']
        depth_samples = rendering_result['composite_depth']

        # Reshape to keep consistent with 'raw' neural-rendered image.
        N = wp.shape[0]
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # Run the post neural renderer to get final image.
        # Here, the post neural renderer is a super-resolution network.
        rgb_image = feature_image[:, :3]
        sr_image = self.post_neural_renderer(
            rgb_image,
            feature_image,
            wp,
            noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
            **{
                k: synthesis_kwargs[k]
                for k in synthesis_kwargs.keys() if k != 'noise_mode'
            })

        return {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image
        }

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        # Compute RGB features, density for arbitrary 3D coordinates.
        # Mostly used for extracting shapes.
        wp = self.mapping(z,
                          c,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        tri_planes = self.backbone.synthesis(wp,
                                             update_emas=update_emas,
                                             **synthesis_kwargs)
        tri_planes = tri_planes.view(len(tri_planes), 3, -1,
                                     tri_planes.shape[-2],
                                     tri_planes.shape[-1])
        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=tri_planes,
            post_module=self.post_module,
            ray_dirs=directions)

        return result

    def sample_mixed(self,
                     coordinates,
                     directions,
                     wp,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        # Same as function `self.sample()`, but expects latent vectors 'wp'
        # instead of Gaussian noise 'z'.
        tri_planes = self.backbone.synthesis(wp,
                                             update_emas=update_emas,
                                             **synthesis_kwargs)
        tri_planes = tri_planes.view(len(tri_planes), 3, -1,
                                     tri_planes.shape[-2],
                                     tri_planes.shape[-1])

        result = self.renderer.get_sigma_rgb(
            wp=wp,
            points=coordinates,
            feature_extractor=self.feature_extractor,
            fc_head=self.fc_head,
            rendering_options=self.rendering_kwargs,
            ref_representation=tri_planes,
            post_module=self.post_module,
            ray_dirs=directions)

        return result

    def forward(self,
                z,
                c,
                c_swapped=None,  # `c_swapped` is swapped pose conditioning.
                style_mixing_prob=0,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                sample_mixed=False,
                coordinates=None,
                **synthesis_kwargs):

        # Render a batch of generated images.
        c_wp = c.clone()
        if c_swapped is not None:
            c_wp = c_swapped.clone()
        wp = self.mapping(z,
                          c_wp,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        if style_mixing_prob > 0:
            cutoff = torch.empty([], dtype=torch.int64,
                                 device=wp.device).random_(1, wp.shape[1])
            cutoff = torch.where(
                torch.rand([], device=wp.device) < style_mixing_prob,
                cutoff, torch.full_like(cutoff, wp.shape[1]))
            wp[:, cutoff:] = self.mapping(torch.randn_like(z),
                                          c,
                                          update_emas=update_emas)[:, cutoff:]
        if not sample_mixed:
            gen_output = self.synthesis(
                wp,
                c,
                update_emas=update_emas,
                neural_rendering_resolution=neural_rendering_resolution,
                **synthesis_kwargs)

            return {
                'wp': wp,
                'gen_output': gen_output,
            }

        else:
            # Only for density regularization in training process.
            assert coordinates is not None
            sample_sigma = self.sample_mixed(coordinates,
                                             torch.randn_like(coordinates),
                                             wp,
                                             update_emas=False)['sigma']

            return {
                'wp': wp,
                'sample_sigma': sample_sigma
            }


class OSGDecoder(nn.Module):
    """Defines fully-connected layer head in EG3D."""
    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64

        self.net = nn.Sequential(
            FullyConnectedLayer(n_features,
                                self.hidden_dim,
                                lr_multiplier=options['decoder_lr_mul']),
            nn.Softplus(),
            FullyConnectedLayer(self.hidden_dim,
                                1 + options['decoder_output_dim'],
                                lr_multiplier=options['decoder_lr_mul']))

    def forward(self, point_features, wp=None, dirs=None):
        # Aggregate features.
        # point_features.shape: [N, 3, M, C].
        # Average across 'X, Y, Z' planes.
        point_features = point_features.mean(1)
        x = point_features

        N, M, C = x.shape
        x = x.view(N * M, C)

        x = self.net(x)
        x = x.view(N, M, -1)

        # Uses sigmoid clamping from MipNeRF.
        rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
        sigma = x[..., 0:1]

        return {'rgb': rgb, 'sigma': sigma}
models/eg3d_generator_fv.py
ADDED
@@ -0,0 +1,320 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python3.8
|
2 |
+
"""Contains the implementation of generator described in EG3D."""
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import numpy as np
|
7 |
+
from models.utils.official_stylegan2_model_helper import MappingNetwork
|
8 |
+
from models.utils.official_stylegan2_model_helper import FullyConnectedLayer
|
9 |
+
from models.utils.eg3d_superres import SuperresolutionHybrid2X
|
10 |
+
from models.utils.eg3d_superres import SuperresolutionHybrid4X
|
11 |
+
from models.utils.eg3d_superres import SuperresolutionHybrid8XDC
|
12 |
+
from models.rendering.renderer import Renderer
|
13 |
+
from models.rendering.feature_extractor import FeatureExtractor
|
14 |
+
from models.volumegan_generator import FeatureVolume
|
15 |
+
from models.volumegan_generator import PositionEncoder
|
16 |
+
|
17 |
+
|
18 |
+
class EG3DGeneratorFV(nn.Module):
|
19 |
+
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
# Input latent (Z) dimensionality.
|
23 |
+
z_dim,
|
24 |
+
# Conditioning label (C) dimensionality.
|
25 |
+
c_dim,
|
26 |
+
# Intermediate latent (W) dimensionality.
|
27 |
+
w_dim,
|
28 |
+
# Final output image resolution.
|
29 |
+
img_resolution,
|
30 |
+
# Number of output color channels.
|
31 |
+
img_channels,
|
32 |
+
# Number of fp16 layers of SR Network.
|
33 |
+
sr_num_fp16_res=0,
|
34 |
+
# Arguments for MappingNetwork.
|
35 |
+
mapping_kwargs={},
|
36 |
+
# Arguments for rendering.
|
37 |
+
rendering_kwargs={},
|
38 |
+
# Arguments for SuperResolution Network.
|
39 |
+
sr_kwargs={},
|
40 |
+
# Configs for FeatureVolume.
|
41 |
+
fv_cfg=dict(feat_res=32,
|
42 |
+
init_res=4,
|
43 |
+
base_channels=256,
|
44 |
+
output_channels=32,
|
45 |
+
w_dim=512),
|
46 |
+
# Configs for position encoder.
|
47 |
+
embed_cfg=dict(input_dim=3, max_freq_log2=10 - 1, N_freqs=10),
|
48 |
+
):
|
49 |
+
super().__init__()
|
50 |
+
self.z_dim = z_dim
|
51 |
+
self.c_dim = c_dim
|
52 |
+
self.w_dim = w_dim
|
53 |
+
self.img_resolution = img_resolution
|
54 |
+
self.img_channels = img_channels
|
55 |
+
|
56 |
+
# Set up mapping network.
|
57 |
+
# Here `num_ws = 2`: one for FeatureVolume Network injection and one for
|
58 |
+
# post_neural_renderer injection.
|
59 |
+
num_ws = 2
|
60 |
+
self.mapping_network = MappingNetwork(z_dim=z_dim,
|
61 |
+
c_dim=c_dim,
|
62 |
+
w_dim=w_dim,
|
63 |
+
num_ws=num_ws,
|
64 |
+
**mapping_kwargs)
|
65 |
+
|
66 |
+
# Set up the overall renderer.
|
67 |
+
self.renderer = Renderer()
|
68 |
+
|
69 |
+
# Set up the feature extractor.
|
70 |
+
self.feature_extractor = FeatureExtractor(ref_mode='feature_volume')
|
71 |
+
|
72 |
+
# Set up the reference representation generator.
|
73 |
+
self.ref_representation_generator = FeatureVolume(**fv_cfg)
|
74 |
+
|
75 |
+
# Set up the position encoder.
|
76 |
+
self.position_encoder = PositionEncoder(**embed_cfg)
|
77 |
+
|
78 |
+
# Set up the post module in the feature extractor.
|
79 |
+
self.post_module = None
|
80 |
+
|
81 |
+
# Set up the post neural renderer.
|
82 |
+
self.post_neural_renderer = None
|
83 |
+
sr_kwargs_total = dict(
|
84 |
+
channels=32,
|
85 |
+
img_resolution=img_resolution,
|
86 |
+
sr_num_fp16_res=sr_num_fp16_res,
|
87 |
+
sr_antialias=rendering_kwargs['sr_antialias'],)
|
88 |
+
sr_kwargs_total.update(**sr_kwargs)
|
89 |
+
if img_resolution == 128:
|
90 |
+
self.post_neural_renderer = SuperresolutionHybrid2X(
|
91 |
+
**sr_kwargs_total)
|
92 |
+
elif img_resolution == 256:
|
93 |
+
self.post_neural_renderer = SuperresolutionHybrid4X(
|
94 |
+
**sr_kwargs_total)
|
95 |
+
elif img_resolution == 512:
|
96 |
+
self.post_neural_renderer = SuperresolutionHybrid8XDC(
|
97 |
+
**sr_kwargs_total)
|
98 |
+
else:
|
99 |
+
raise TypeError(f'Unsupported image resolution: {img_resolution}!')
|
100 |
+
|
101 |
+
# Set up the fully-connected layer head.
|
102 |
+
self.fc_head = OSGDecoder(
|
103 |
+
32, {
|
104 |
+
'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
|
105 |
+
'decoder_output_dim': 32
|
106 |
+
})
|
107 |
+
|
108 |
+
# Set up some rendering related arguments.
|
109 |
+
self.neural_rendering_resolution = rendering_kwargs.get(
|
110 |
+
'resolution', 64)
|
111 |
+
self.rendering_kwargs = rendering_kwargs
|
112 |
+
|
113 |
+
def mapping(self,
|
114 |
+
z,
|
115 |
+
c,
|
116 |
+
truncation_psi=1,
|
117 |
+
truncation_cutoff=None,
|
118 |
+
update_emas=False):
|
119 |
+
if self.rendering_kwargs['c_gen_conditioning_zero']:
|
120 |
+
c = torch.zeros_like(c)
|
121 |
+
return self.mapping_network(z,
|
122 |
+
c *
|
123 |
+
self.rendering_kwargs.get('c_scale', 0),
|
124 |
+
truncation_psi=truncation_psi,
|
125 |
+
truncation_cutoff=truncation_cutoff,
|
126 |
+
update_emas=update_emas)
|
127 |
+
|
128 |
+
def synthesis(self,
|
129 |
+
wp,
|
130 |
+
c,
|
131 |
+
neural_rendering_resolution=None,
|
132 |
+
update_emas=False,
|
133 |
+
**synthesis_kwargs):
|
134 |
+
cam2world_matrix = c[:, :16].view(-1, 4, 4)
|
135 |
+
if self.rendering_kwargs.get('random_pose', False):
|
136 |
+
cam2world_matrix = None
|
137 |
+
|
138 |
+
if neural_rendering_resolution is None:
|
139 |
+
neural_rendering_resolution = self.neural_rendering_resolution
|
140 |
+
else:
|
141 |
+
self.neural_rendering_resolution = neural_rendering_resolution
|
142 |
+
|
143 |
+
feature_volume = self.ref_representation_generator(wp)
|
144 |
+
|
145 |
+
rendering_result = self.renderer(
|
146 |
+
wp=wp,
|
147 |
+
feature_extractor=self.feature_extractor,
|
148 |
+
rendering_options=self.rendering_kwargs,
|
149 |
+
cam2world_matrix=cam2world_matrix,
|
150 |
+
position_encoder=self.position_encoder,
|
151 |
+
ref_representation=feature_volume,
|
152 |
+
post_module=self.post_module,
|
153 |
+
fc_head=self.fc_head)
|
154 |
+
|
155 |
+
feature_samples = rendering_result['composite_rgb']
|
156 |
+
depth_samples = rendering_result['composite_depth']
|
157 |
+
|
158 |
+
# Reshape to keep consistent with 'raw' neural-rendered image.
|
159 |
+
N = wp.shape[0]
|
160 |
+
H = W = self.neural_rendering_resolution
|
161 |
+
feature_image = feature_samples.permute(0, 2, 1).reshape(
|
162 |
+
N, feature_samples.shape[-1], H, W).contiguous()
|
163 |
+
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
|
164 |
+
|
165 |
+
# Run the post neural renderer to get final image.
|
166 |
+
# Here, the post neural renderer is a super-resolution network.
|
167 |
+
rgb_image = feature_image[:, :3]
|
168 |
+
sr_image = self.post_neural_renderer(
|
169 |
+
rgb_image,
|
170 |
+
feature_image,
|
171 |
+
wp,
|
172 |
+
noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
|
173 |
+
**{
|
174 |
+
k: synthesis_kwargs[k]
|
175 |
+
for k in synthesis_kwargs.keys() if k != 'noise_mode'
|
176 |
+
})
|
177 |
+
|
178 |
+
return {
|
179 |
+
'image': sr_image,
|
180 |
+
'image_raw': rgb_image,
|
181 |
+
'image_depth': depth_image
|
182 |
+
}
|
183 |
+
|
184 |
+
def sample(self,
|
185 |
+
coordinates,
|
186 |
+
directions,
|
187 |
+
z,
|
188 |
+
c,
|
189 |
+
truncation_psi=1,
|
190 |
+
truncation_cutoff=None,
|
191 |
+
update_emas=False):
|
192 |
+
# Compute RGB features, density for arbitrary 3D coordinates.
|
193 |
+
# Mostly used for extracting shapes.
|
194 |
+
wp = self.mapping_network(z,
|
195 |
+
c,
|
196 |
+
truncation_psi=truncation_psi,
|
197 |
+
truncation_cutoff=truncation_cutoff,
|
198 |
+
update_emas=update_emas)
|
199 |
+
feature_volume = self.ref_representation_generator(wp)
|
200 |
+
result = self.renderer.get_sigma_rgb(
|
201 |
+
wp=wp,
|
202 |
+
points=coordinates,
|
203 |
+
feature_extractor=self.feature_extractor,
|
204 |
+
fc_head=self.fc_head,
|
205 |
+
rendering_options=self.rendering_kwargs,
|
206 |
+
ref_representation=feature_volume,
|
207 |
+
position_encoder=self.position_encoder,
|
208 |
+
post_module=self.post_module,
|
209 |
+
ray_dirs=directions)
|
210 |
+
|
211 |
+
return result
|
212 |
+
|
213 |
+
def sample_mixed(self,
|
214 |
+
coordinates,
|
215 |
+
directions,
|
216 |
+
wp):
|
217 |
+
# Same as function `self.sample()`, but expects latent vectors 'wp'
|
218 |
+
# instead of Gaussian noise 'z'.
|
219 |
+
feature_volume = self.ref_representation_generator(wp)
|
220 |
+
result = self.renderer.get_sigma_rgb(
|
221 |
+
wp=wp,
|
222 |
+
points=coordinates,
|
223 |
+
feature_extractor=self.feature_extractor,
|
224 |
+
fc_head=self.fc_head,
|
225 |
+
rendering_options=self.rendering_kwargs,
|
226 |
+
ref_representation=feature_volume,
|
227 |
+
position_encoder=self.position_encoder,
|
228 |
+
post_module=self.post_module,
|
229 |
+
ray_dirs=directions)
|
230 |
+
|
231 |
+
return result
|
232 |
+
|
233 |
+
def forward(self,
|
234 |
+
z,
|
235 |
+
c,
|
236 |
+
c_swapped=None, # `c_swapped` is swapped pose conditioning.
|
237 |
+
style_mixing_prob=0,
|
238 |
+
truncation_psi=1,
|
239 |
+
truncation_cutoff=None,
|
240 |
+
neural_rendering_resolution=None,
|
241 |
+
update_emas=False,
|
242 |
+
sample_mixed=False,
|
243 |
+
coordinates=None,
|
244 |
+
**synthesis_kwargs):
|
245 |
+
|
246 |
+
# Render a batch of generated images.
|
247 |
+
c_wp = c.clone()
|
248 |
+
if c_swapped is not None:
|
249 |
+
c_wp = c_swapped.clone()
|
250 |
+
wp = self.mapping_network(z,
|
251 |
+
c_wp,
|
252 |
+
truncation_psi=truncation_psi,
|
253 |
+
truncation_cutoff=truncation_cutoff,
|
254 |
+
update_emas=update_emas)
|
255 |
+
if style_mixing_prob > 0:
|
256 |
+
cutoff = torch.empty([], dtype=torch.int64,
|
257 |
+
device=wp.device).random_(1, wp.shape[1])
|
258 |
+
cutoff = torch.where(
|
259 |
+
torch.rand([], device=wp.device) < style_mixing_prob, cutoff,
|
260 |
+
torch.full_like(cutoff, wp.shape[1]))
|
261 |
+
wp[:, cutoff:] = self.mapping_network(
|
262 |
+
torch.randn_like(z), c, update_emas=update_emas)[:, cutoff:]
|
263 |
+
if not sample_mixed:
|
264 |
+
gen_output = self.synthesis(
|
265 |
+
wp,
|
266 |
+
c,
|
267 |
+
update_emas=update_emas,
|
268 |
+
neural_rendering_resolution=neural_rendering_resolution,
|
269 |
+
**synthesis_kwargs)
|
270 |
+
|
271 |
+
return {
|
272 |
+
'wp': wp,
|
273 |
+
'gen_output': gen_output,
|
274 |
+
}
|
275 |
+
|
276 |
+
else:
|
277 |
+
# Only for density regularization in training process.
|
278 |
+
assert coordinates is not None
|
279 |
+
sample_sigma = self.sample_mixed(coordinates,
|
280 |
+
torch.randn_like(coordinates),
|
281 |
+
wp)['sigma']
|
282 |
+
|
283 |
+
return {
|
284 |
+
'wp': wp,
|
285 |
+
'sample_sigma': sample_sigma
|
286 |
+
}
|
287 |
+
|
288 |
+
|
289 |
+
class OSGDecoder(nn.Module):
|
290 |
+
"""Defines fully-connected layer head in EG3D."""
|
291 |
+
def __init__(self, n_features, options):
|
292 |
+
super().__init__()
|
293 |
+
self.hidden_dim = 64
|
294 |
+
|
295 |
+
self.net = nn.Sequential(
|
296 |
+
FullyConnectedLayer(n_features,
|
297 |
+
self.hidden_dim,
|
298 |
+
lr_multiplier=options['decoder_lr_mul']),
|
299 |
+
nn.Softplus(),
|
300 |
+
FullyConnectedLayer(self.hidden_dim,
|
301 |
+
1 + options['decoder_output_dim'],
|
302 |
+
lr_multiplier=options['decoder_lr_mul']))
|
303 |
+
|
304 |
+
def forward(self, point_features, wp=None, dirs=None):
|
305 |
+
# point_features.shape: [N, C, M, 1].
|
306 |
+
point_features = point_features.squeeze(-1)
|
307 |
+
point_features = point_features.permute(0, 2, 1)
|
308 |
+
x = point_features
|
309 |
+
|
310 |
+
N, M, C = x.shape
|
311 |
+
x = x.reshape(N * M, C)
|
312 |
+
|
313 |
+
x = self.net(x)
|
314 |
+
x = x.reshape(N, M, -1)
|
315 |
+
|
316 |
+
# Uses sigmoid clamping from MipNeRF
|
317 |
+
rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001
|
318 |
+
sigma = x[..., 0:1]
|
319 |
+
|
320 |
+
return {'rgb': rgb, 'sigma': sigma}
|
models/ghfeat_encoder.py
ADDED
@@ -0,0 +1,563 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python3.7
|
2 |
+
"""Contains the implementation of encoder used in GH-Feat (including IDInvert).
|
3 |
+
|
4 |
+
ResNet is used as the backbone.
|
5 |
+
|
6 |
+
GH-Feat paper: https://arxiv.org/pdf/2007.10379.pdf
|
7 |
+
IDInvert paper: https://arxiv.org/pdf/2004.00049.pdf
|
8 |
+
|
9 |
+
NOTE: Please use `latent_num` and `num_latents_per_head` to control the
|
10 |
+
inversion space, such as Y-space used in GH-Feat and W-space used in IDInvert.
|
11 |
+
In addition, IDInvert sets `use_fpn` and `use_sam` as `False` by default.
|
12 |
+
"""
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
import torch.nn.functional as F
|
19 |
+
import torch.distributed as dist
|
20 |
+
|
21 |
+
__all__ = ['GHFeatEncoder']
|
22 |
+
|
23 |
+
# Resolutions allowed.
|
24 |
+
_RESOLUTIONS_ALLOWED = [8, 16, 32, 64, 128, 256, 512, 1024]
|
25 |
+
|
26 |
+
# pylint: disable=missing-function-docstring
|
27 |
+
|
28 |
+
class BasicBlock(nn.Module):
|
29 |
+
"""Implementation of ResNet BasicBlock."""
|
30 |
+
|
31 |
+
expansion = 1
|
32 |
+
|
33 |
+
def __init__(self,
|
34 |
+
inplanes,
|
35 |
+
planes,
|
36 |
+
base_width=64,
|
37 |
+
stride=1,
|
38 |
+
groups=1,
|
39 |
+
dilation=1,
|
40 |
+
norm_layer=None,
|
41 |
+
downsample=None):
|
42 |
+
super().__init__()
|
43 |
+
if base_width != 64:
|
44 |
+
raise ValueError(f'BasicBlock of ResNet only supports '
|
45 |
+
f'`base_width=64`, but {base_width} received!')
|
46 |
+
if stride not in [1, 2]:
|
47 |
+
raise ValueError(f'BasicBlock of ResNet only supports `stride=1` '
|
48 |
+
f'and `stride=2`, but {stride} received!')
|
49 |
+
if groups != 1:
|
50 |
+
raise ValueError(f'BasicBlock of ResNet only supports `groups=1`, '
|
51 |
+
f'but {groups} received!')
|
52 |
+
if dilation != 1:
|
53 |
+
raise ValueError(f'BasicBlock of ResNet only supports '
|
54 |
+
f'`dilation=1`, but {dilation} received!')
|
55 |
+
assert self.expansion == 1
|
56 |
+
|
57 |
+
self.stride = stride
|
58 |
+
if norm_layer is None:
|
59 |
+
norm_layer = nn.BatchNorm2d
|
60 |
+
self.conv1 = nn.Conv2d(in_channels=inplanes,
|
61 |
+
out_channels=planes,
|
62 |
+
kernel_size=3,
|
63 |
+
stride=stride,
|
64 |
+
padding=1,
|
65 |
+
groups=1,
|
66 |
+
dilation=1,
|
67 |
+
bias=False)
|
68 |
+
self.bn1 = norm_layer(planes)
|
69 |
+
self.relu = nn.ReLU(inplace=True)
|
70 |
+
self.conv2 = nn.Conv2d(in_channels=planes,
|
71 |
+
out_channels=planes,
|
72 |
+
kernel_size=3,
|
73 |
+
stride=1,
|
74 |
+
padding=1,
|
75 |
+
groups=1,
|
76 |
+
dilation=1,
|
77 |
+
bias=False)
|
78 |
+
self.bn2 = norm_layer(planes)
|
79 |
+
self.downsample = downsample
|
80 |
+
|
81 |
+
def forward(self, x):
|
82 |
+
identity = self.downsample(x) if self.downsample is not None else x
|
83 |
+
|
84 |
+
out = self.conv1(x)
|
85 |
+
out = self.bn1(out)
|
86 |
+
out = self.relu(out)
|
87 |
+
|
88 |
+
out = self.conv2(out)
|
89 |
+
out = self.bn2(out)
|
90 |
+
out = self.relu(out + identity)
|
91 |
+
|
92 |
+
return out
|
93 |
+
|
94 |
+
|
95 |
+
class Bottleneck(nn.Module):
|
96 |
+
"""Implementation of ResNet Bottleneck."""
|
97 |
+
|
98 |
+
expansion = 4
|
99 |
+
|
100 |
+
def __init__(self,
|
101 |
+
inplanes,
|
102 |
+
planes,
|
103 |
+
base_width=64,
|
104 |
+
stride=1,
|
105 |
+
groups=1,
|
106 |
+
dilation=1,
|
107 |
+
norm_layer=None,
|
108 |
+
downsample=None):
|
109 |
+
super().__init__()
|
110 |
+
if stride not in [1, 2]:
|
111 |
+
raise ValueError(f'Bottleneck of ResNet only supports `stride=1` '
|
112 |
+
f'and `stride=2`, but {stride} received!')
|
113 |
+
|
114 |
+
width = int(planes * (base_width / 64)) * groups
|
115 |
+
self.stride = stride
|
116 |
+
if norm_layer is None:
|
117 |
+
norm_layer = nn.BatchNorm2d
|
118 |
+
self.conv1 = nn.Conv2d(in_channels=inplanes,
|
119 |
+
out_channels=width,
|
120 |
+
kernel_size=1,
|
121 |
+
stride=1,
|
122 |
+
padding=0,
|
123 |
+
dilation=1,
|
124 |
+
groups=1,
|
125 |
+
bias=False)
|
126 |
+
self.bn1 = norm_layer(width)
|
127 |
+
self.conv2 = nn.Conv2d(in_channels=width,
|
128 |
+
out_channels=width,
|
129 |
+
kernel_size=3,
|
130 |
+
stride=stride,
|
131 |
+
padding=dilation,
|
132 |
+
groups=groups,
|
133 |
+
dilation=dilation,
|
134 |
+
bias=False)
|
135 |
+
self.bn2 = norm_layer(width)
|
136 |
+
self.conv3 = nn.Conv2d(in_channels=width,
|
137 |
+
out_channels=planes * self.expansion,
|
138 |
+
kernel_size=1,
|
139 |
+
stride=1,
|
140 |
+
padding=0,
|
141 |
+
dilation=1,
|
142 |
+
groups=1,
|
143 |
+
bias=False)
|
144 |
+
self.bn3 = norm_layer(planes * self.expansion)
|
145 |
+
self.relu = nn.ReLU(inplace=True)
|
146 |
+
self.downsample = downsample
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
identity = self.downsample(x) if self.downsample is not None else x
|
150 |
+
|
151 |
+
out = self.conv1(x)
|
152 |
+
out = self.bn1(out)
|
153 |
+
out = self.relu(out)
|
154 |
+
|
155 |
+
out = self.conv2(out)
|
156 |
+
out = self.bn2(out)
|
157 |
+
out = self.relu(out)
|
158 |
+
|
159 |
+
out = self.conv3(out)
|
160 |
+
out = self.bn3(out)
|
161 |
+
out = self.relu(out + identity)
|
162 |
+
|
163 |
+
return out
|
164 |
+
|
165 |
+
|
166 |
+
class GHFeatEncoder(nn.Module):
|
167 |
+
"""Define the ResNet-based encoder network for GAN inversion.
|
168 |
+
|
169 |
+
On top of the backbone, there are several task-heads to produce inverted
|
170 |
+
codes. Please use `latent_dim` and `num_latents_per_head` to define the
|
171 |
+
structure. For example, `latent_dim = [512] * 14` and
|
172 |
+
`num_latents_per_head = [4, 4, 6]` can be used for StyleGAN inversion with
|
173 |
+
14-layer latent codes, where 3 task heads (corresponding to 4, 4, 6 layers,
|
174 |
+
respectively) are used.
|
175 |
+
|
176 |
+
Settings for the encoder network:
|
177 |
+
|
178 |
+
(1) resolution: The resolution of the output image.
|
179 |
+
(2) latent_dim: Dimension of the latent space. A number (one code will be
|
180 |
+
produced), or a list of numbers regarding layer-wise latent codes.
|
181 |
+
(3) num_latents_per_head: Number of latents that is produced by each head.
|
182 |
+
(4) image_channels: Number of channels of the output image. (default: 3)
|
183 |
+
(5) final_res: Final resolution of the convolutional layers. (default: 4)
|
184 |
+
|
185 |
+
ResNet-related settings:
|
186 |
+
|
187 |
+
(1) network_depth: Depth of the network, like 18 for ResNet18. (default: 18)
|
188 |
+
(2) inplanes: Number of channels of the first convolutional layer.
|
189 |
+
(default: 64)
|
190 |
+
(3) groups: Groups of the convolution, used in ResNet. (default: 1)
|
191 |
+
(4) width_per_group: Number of channels per group, used in ResNet.
|
192 |
+
(default: 64)
|
193 |
+
(5) replace_stride_with_dilation: Whether to replace stride with dilation,
|
194 |
+
used in ResNet. (default: None)
|
195 |
+
(6) norm_layer: Normalization layer used in the encoder. If set as `None`,
|
196 |
+
`nn.BatchNorm2d` will be used. Also, please NOTE that when using batch
|
197 |
+
normalization, the batch size is required to be larger than one for
|
198 |
+
training. (default: nn.BatchNorm2d)
|
199 |
+
(7) max_channels: Maximum number of channels in each layer. (default: 512)
|
200 |
+
|
201 |
+
Task-head related settings:
|
202 |
+
|
203 |
+
(1) use_fpn: Whether to use Feature Pyramid Network (FPN) before outputting
|
204 |
+
the latent code. (default: True)
|
205 |
+
(2) fpn_channels: Number of channels used in FPN. (default: 512)
|
206 |
+
(3) use_sam: Whether to use Spatial Alignment Module (SAM) before outputting
|
207 |
+
the latent code. (default: True)
|
208 |
+
(4) sam_channels: Number of channels used in SAM. (default: 512)
|
209 |
+
"""
|
210 |
+
|
211 |
+
arch_settings = {
|
212 |
+
18: (BasicBlock, [2, 2, 2, 2]),
|
213 |
+
34: (BasicBlock, [3, 4, 6, 3]),
|
214 |
+
50: (Bottleneck, [3, 4, 6, 3]),
|
215 |
+
101: (Bottleneck, [3, 4, 23, 3]),
|
216 |
+
152: (Bottleneck, [3, 8, 36, 3])
|
217 |
+
}
|
218 |
+
|
219 |
+
def __init__(self,
|
220 |
+
resolution,
|
221 |
+
latent_dim,
|
222 |
+
num_latents_per_head,
|
223 |
+
image_channels=3,
|
224 |
+
final_res=4,
|
225 |
+
network_depth=18,
|
226 |
+
inplanes=64,
|
227 |
+
groups=1,
|
228 |
+
width_per_group=64,
|
229 |
+
replace_stride_with_dilation=None,
|
230 |
+
norm_layer=nn.BatchNorm2d,
|
231 |
+
max_channels=512,
|
232 |
+
use_fpn=True,
|
233 |
+
fpn_channels=512,
|
234 |
+
use_sam=True,
|
235 |
+
sam_channels=512):
|
236 |
+
super().__init__()
|
237 |
+
|
238 |
+
if resolution not in _RESOLUTIONS_ALLOWED:
|
239 |
+
raise ValueError(f'Invalid resolution: `{resolution}`!\n'
|
240 |
+
f'Resolutions allowed: {_RESOLUTIONS_ALLOWED}.')
|
241 |
+
if network_depth not in self.arch_settings:
|
242 |
+
raise ValueError(f'Invalid network depth: `{network_depth}`!\n'
|
243 |
+
f'Options allowed: '
|
244 |
+
f'{list(self.arch_settings.keys())}.')
|
245 |
+
if isinstance(latent_dim, int):
|
246 |
+
latent_dim = [latent_dim]
|
247 |
+
assert isinstance(latent_dim, (list, tuple))
|
248 |
+
assert isinstance(num_latents_per_head, (list, tuple))
|
249 |
+
assert sum(num_latents_per_head) == len(latent_dim)
|
250 |
+
|
251 |
+
self.resolution = resolution
|
252 |
+
self.latent_dim = latent_dim
|
253 |
+
self.num_latents_per_head = num_latents_per_head
|
254 |
+
self.num_heads = len(self.num_latents_per_head)
|
255 |
+
self.image_channels = image_channels
|
256 |
+
self.final_res = final_res
|
257 |
+
self.inplanes = inplanes
|
258 |
+
self.network_depth = network_depth
|
259 |
+
self.groups = groups
|
260 |
+
self.dilation = 1
|
261 |
+
self.base_width = width_per_group
|
262 |
+
self.replace_stride_with_dilation = replace_stride_with_dilation
|
263 |
+
if norm_layer is None:
|
264 |
+
norm_layer = nn.BatchNorm2d
|
265 |
+
if norm_layer == nn.BatchNorm2d and dist.is_initialized():
|
266 |
+
norm_layer = nn.SyncBatchNorm
|
267 |
+
self.norm_layer = norm_layer
|
268 |
+
self.max_channels = max_channels
|
269 |
+
self.use_fpn = use_fpn
|
270 |
+
self.fpn_channels = fpn_channels
|
271 |
+
self.use_sam = use_sam
|
272 |
+
self.sam_channels = sam_channels
|
273 |
+
|
274 |
+
block_fn, num_blocks_per_stage = self.arch_settings[network_depth]
|
275 |
+
|
276 |
+
self.num_stages = int(np.log2(resolution // final_res)) - 1
|
277 |
+
# Add one block for additional stages.
|
278 |
+
for i in range(len(num_blocks_per_stage), self.num_stages):
|
279 |
+
num_blocks_per_stage.append(1)
|
280 |
+
if replace_stride_with_dilation is None:
|
281 |
+
replace_stride_with_dilation = [False] * self.num_stages
|
282 |
+
|
283 |
+
# Backbone.
|
284 |
+
self.conv1 = nn.Conv2d(in_channels=self.image_channels,
|
285 |
+
out_channels=self.inplanes,
|
286 |
+
kernel_size=7,
|
287 |
+
stride=2,
|
288 |
+
padding=3,
|
289 |
+
bias=False)
|
290 |
+
self.bn1 = norm_layer(self.inplanes)
|
291 |
+
self.relu = nn.ReLU(inplace=True)
|
292 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
293 |
+
|
294 |
+
self.stage_channels = [self.inplanes]
|
295 |
+
self.stages = nn.ModuleList()
|
296 |
+
for i in range(self.num_stages):
|
297 |
+
inplanes = self.inplanes if i == 0 else planes * block_fn.expansion
|
298 |
+
planes = min(self.max_channels, self.inplanes * (2 ** i))
|
299 |
+
num_blocks = num_blocks_per_stage[i]
|
300 |
+
stride = 1 if i == 0 else 2
|
301 |
+
dilate = replace_stride_with_dilation[i]
|
302 |
+
self.stages.append(self._make_stage(block_fn=block_fn,
|
303 |
+
inplanes=inplanes,
|
304 |
+
planes=planes,
|
305 |
+
num_blocks=num_blocks,
|
306 |
+
stride=stride,
|
307 |
+
dilate=dilate))
|
308 |
+
self.stage_channels.append(planes * block_fn.expansion)
|
309 |
+
|
310 |
+
if self.num_heads > len(self.stage_channels):
|
311 |
+
raise ValueError('Number of task heads is larger than number of '
|
312 |
+
'stages! Please reduce the number of heads.')
|
313 |
+
|
314 |
+
# Task-head.
|
315 |
+
if self.num_heads == 1:
|
316 |
+
self.use_fpn = False
|
317 |
+
self.use_sam = False
|
318 |
+
|
319 |
+
if self.use_fpn:
|
320 |
+
fpn_pyramid_channels = self.stage_channels[-self.num_heads:]
|
321 |
+
self.fpn = FPN(pyramid_channels=fpn_pyramid_channels,
|
322 |
+
out_channels=self.fpn_channels)
|
323 |
+
if self.use_sam:
|
324 |
+
if self.use_fpn:
|
325 |
+
sam_pyramid_channels = [self.fpn_channels] * self.num_heads
|
326 |
+
else:
|
327 |
+
sam_pyramid_channels = self.stage_channels[-self.num_heads:]
|
328 |
+
self.sam = SAM(pyramid_channels=sam_pyramid_channels,
|
329 |
+
out_channels=self.sam_channels)
|
330 |
+
|
331 |
+
self.heads = nn.ModuleList()
|
332 |
+
for head_idx in range(self.num_heads):
|
333 |
+
# Parse in_channels.
|
334 |
+
if self.use_sam:
|
335 |
+
in_channels = self.sam_channels
|
336 |
+
elif self.use_fpn:
|
337 |
+
in_channels = self.fpn_channels
|
338 |
+
else:
|
339 |
+
in_channels = self.stage_channels[head_idx - self.num_heads]
|
340 |
+
in_channels = in_channels * final_res * final_res
|
341 |
+
|
342 |
+
# Parse out_channels.
|
343 |
+
start_latent_idx = sum(self.num_latents_per_head[:head_idx])
|
344 |
+
end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
|
345 |
+
out_channels = sum(self.latent_dim[start_latent_idx:end_latent_idx])
|
346 |
+
|
347 |
+
self.heads.append(CodeHead(in_channels=in_channels,
|
348 |
+
out_channels=out_channels,
|
349 |
+
norm_layer=self.norm_layer))
|
350 |
+
|
351 |
+
def _make_stage(self,
|
352 |
+
block_fn,
|
353 |
+
inplanes,
|
354 |
+
planes,
|
355 |
+
num_blocks,
|
356 |
+
stride,
|
357 |
+
dilate):
|
358 |
+
norm_layer = self.norm_layer
|
359 |
+
downsample = None
|
360 |
+
previous_dilation = self.dilation
|
361 |
+
if dilate:
|
362 |
+
self.dilation *= stride
|
363 |
+
stride = 1
|
364 |
+
if stride != 1 or inplanes != planes * block_fn.expansion:
|
365 |
+
downsample = nn.Sequential(
|
366 |
+
nn.Conv2d(in_channels=inplanes,
|
367 |
+
out_channels=planes * block_fn.expansion,
|
368 |
+
kernel_size=1,
|
369 |
+
stride=stride,
|
370 |
+
padding=0,
|
371 |
+
dilation=1,
|
372 |
+
groups=1,
|
373 |
+
bias=False),
|
374 |
+
norm_layer(planes * block_fn.expansion),
|
375 |
+
)
|
376 |
+
|
377 |
+
blocks = []
|
378 |
+
blocks.append(block_fn(inplanes=inplanes,
|
379 |
+
planes=planes,
|
380 |
+
base_width=self.base_width,
|
381 |
+
stride=stride,
|
382 |
+
groups=self.groups,
|
383 |
+
dilation=previous_dilation,
|
384 |
+
norm_layer=norm_layer,
|
385 |
+
downsample=downsample))
|
386 |
+
for _ in range(1, num_blocks):
|
387 |
+
blocks.append(block_fn(inplanes=planes * block_fn.expansion,
|
388 |
+
planes=planes,
|
389 |
+
base_width=self.base_width,
|
390 |
+
stride=1,
|
391 |
+
groups=self.groups,
|
392 |
+
dilation=self.dilation,
|
393 |
+
norm_layer=norm_layer,
|
394 |
+
downsample=None))
|
395 |
+
|
396 |
+
return nn.Sequential(*blocks)
|
397 |
+
|
398 |
+
def forward(self, x):
|
399 |
+
x = self.conv1(x)
|
400 |
+
x = self.bn1(x)
|
401 |
+
x = self.relu(x)
|
402 |
+
x = self.maxpool(x)
|
403 |
+
|
404 |
+
features = [x]
|
405 |
+
for i in range(self.num_stages):
|
406 |
+
x = self.stages[i](x)
|
407 |
+
features.append(x)
|
408 |
+
features = features[-self.num_heads:]
|
409 |
+
|
410 |
+
if self.use_fpn:
|
411 |
+
features = self.fpn(features)
|
412 |
+
if self.use_sam:
|
413 |
+
features = self.sam(features)
|
414 |
+
else:
|
415 |
+
final_size = features[-1].shape[2:]
|
416 |
+
for i in range(self.num_heads - 1):
|
417 |
+
features[i] = F.adaptive_avg_pool2d(features[i], final_size)
|
418 |
+
|
419 |
+
outputs = []
|
420 |
+
for head_idx in range(self.num_heads):
|
421 |
+
codes = self.heads[head_idx](features[head_idx])
|
422 |
+
start_latent_idx = sum(self.num_latents_per_head[:head_idx])
|
423 |
+
end_latent_idx = sum(self.num_latents_per_head[:head_idx + 1])
|
424 |
+
split_size = self.latent_dim[start_latent_idx:end_latent_idx]
|
425 |
+
outputs.extend(torch.split(codes, split_size, dim=1))
|
426 |
+
max_dim = max(self.latent_dim)
|
427 |
+
for i, dim in enumerate(self.latent_dim):
|
428 |
+
if dim < max_dim:
|
429 |
+
outputs[i] = F.pad(outputs[i], (0, max_dim - dim))
|
430 |
+
outputs[i] = outputs[i].unsqueeze(1)
|
431 |
+
|
432 |
+
return torch.cat(outputs, dim=1)
|
433 |
+
|
434 |
+
|
435 |
+
class FPN(nn.Module):
|
436 |
+
"""Implementation of Feature Pyramid Network (FPN).
|
437 |
+
|
438 |
+
The input of this module is a pyramid of features with reducing resolutions.
|
439 |
+
Then, this module fuses these multi-level features from `top_level` to
|
440 |
+
`bottom_level`. In particular, starting from the `top_level`, each feature
|
441 |
+
is convoluted, upsampled, and fused into its previous feature (which is also
|
442 |
+
convoluted).
|
443 |
+
|
444 |
+
Args:
|
445 |
+
pyramid_channels: A list of integers, each of which indicates the number
|
446 |
+
of channels of the feature from a particular level.
|
447 |
+
out_channels: Number of channels for each output.
|
448 |
+
|
449 |
+
Returns:
|
450 |
+
A list of feature maps, each of which has `out_channels` channels.
|
451 |
+
"""
|
452 |
+
|
453 |
+
def __init__(self, pyramid_channels, out_channels):
|
454 |
+
super().__init__()
|
455 |
+
assert isinstance(pyramid_channels, (list, tuple))
|
456 |
+
self.num_levels = len(pyramid_channels)
|
457 |
+
|
458 |
+
self.lateral_layers = nn.ModuleList()
|
459 |
+
self.feature_layers = nn.ModuleList()
|
460 |
+
for i in range(self.num_levels):
|
461 |
+
in_channels = pyramid_channels[i]
|
462 |
+
self.lateral_layers.append(nn.Conv2d(in_channels=in_channels,
|
463 |
+
out_channels=out_channels,
|
464 |
+
kernel_size=3,
|
465 |
+
padding=1,
|
466 |
+
bias=True))
|
467 |
+
self.feature_layers.append(nn.Conv2d(in_channels=out_channels,
|
468 |
+
out_channels=out_channels,
|
469 |
+
kernel_size=3,
|
470 |
+
padding=1,
|
471 |
+
bias=True))
|
472 |
+
|
473 |
+
def forward(self, inputs):
|
474 |
+
if len(inputs) != self.num_levels:
|
475 |
+
raise ValueError('Number of inputs and `num_levels` mismatch!')
|
476 |
+
|
477 |
+
# Project all related features to `out_channels`.
|
478 |
+
laterals = []
|
479 |
+
for i in range(self.num_levels):
|
480 |
+
laterals.append(self.lateral_layers[i](inputs[i]))
|
481 |
+
|
482 |
+
# Fusion, starting from `top_level`.
|
483 |
+
for i in range(self.num_levels - 1, 0, -1):
|
484 |
+
scale_factor = laterals[i - 1].shape[2] // laterals[i].shape[2]
|
485 |
+
laterals[i - 1] = (laterals[i - 1] +
|
486 |
+
F.interpolate(laterals[i],
|
487 |
+
mode='nearest',
|
488 |
+
scale_factor=scale_factor))
|
489 |
+
|
490 |
+
# Get outputs.
|
491 |
+
outputs = []
|
492 |
+
for i, lateral in enumerate(laterals):
|
493 |
+
outputs.append(self.feature_layers[i](lateral))
|
494 |
+
|
495 |
+
return outputs
|
496 |
+
|
497 |
+
|
498 |
+
class SAM(nn.Module):
|
499 |
+
"""Implementation of Spatial Alignment Module (SAM).
|
500 |
+
|
501 |
+
The input of this module is a pyramid of features with reducing resolutions.
|
502 |
+
Then this module downsamples all levels of feature to the minimum resolution
|
503 |
+
and fuses it with the smallest feature map.
|
504 |
+
|
505 |
+
Args:
|
506 |
+
pyramid_channels: A list of integers, each of which indicates the number
|
507 |
+
of channels of the feature from a particular level.
|
508 |
+
out_channels: Number of channels for each output.
|
509 |
+
|
510 |
+
Returns:
|
511 |
+
A list of feature maps, each of which has `out_channels` channels.
|
512 |
+
"""
|
513 |
+
|
514 |
+
def __init__(self, pyramid_channels, out_channels):
|
515 |
+
super().__init__()
|
516 |
+
assert isinstance(pyramid_channels, (list, tuple))
|
517 |
+
self.num_levels = len(pyramid_channels)
|
518 |
+
|
519 |
+
self.fusion_layers = nn.ModuleList()
|
520 |
+
for i in range(self.num_levels):
|
521 |
+
in_channels = pyramid_channels[i]
|
522 |
+
self.fusion_layers.append(nn.Conv2d(in_channels=in_channels,
|
523 |
+
out_channels=out_channels,
|
524 |
+
kernel_size=3,
|
525 |
+
padding=1,
|
526 |
+
bias=True))
|
527 |
+
|
528 |
+
def forward(self, inputs):
|
529 |
+
if len(inputs) != self.num_levels:
|
530 |
+
raise ValueError('Number of inputs and `num_levels` mismatch!')
|
531 |
+
|
532 |
+
output_res = inputs[-1].shape[2:]
|
533 |
+
for i in range(self.num_levels - 1, -1, -1):
|
534 |
+
if i != self.num_levels - 1:
|
535 |
+
inputs[i] = F.adaptive_avg_pool2d(inputs[i], output_res)
|
536 |
+
inputs[i] = self.fusion_layers[i](inputs[i])
|
537 |
+
if i != self.num_levels - 1:
|
538 |
+
inputs[i] = inputs[i] + inputs[-1]
|
539 |
+
|
540 |
+
return inputs
|
541 |
+
|
542 |
+
|
543 |
+
class CodeHead(nn.Module):
|
544 |
+
"""Implementation of the task-head to produce inverted codes."""
|
545 |
+
|
546 |
+
def __init__(self, in_channels, out_channels, norm_layer):
|
547 |
+
super().__init__()
|
548 |
+
self.fc = nn.Linear(in_channels, out_channels, bias=True)
|
549 |
+
if norm_layer is None:
|
550 |
+
self.norm = nn.Identity()
|
551 |
+
else:
|
552 |
+
self.norm = norm_layer(out_channels)
|
553 |
+
|
554 |
+
def forward(self, x):
|
555 |
+
if x.ndim > 2:
|
556 |
+
x = x.flatten(start_dim=1)
|
557 |
+
latent = self.fc(x)
|
558 |
+
latent = latent.unsqueeze(2).unsqueeze(3)
|
559 |
+
latent = self.norm(latent)
|
560 |
+
|
561 |
+
return latent.flatten(start_dim=1)
|
562 |
+
|
563 |
+
# pylint: enable=missing-function-docstring
|
models/inception_model.py
ADDED
@@ -0,0 +1,562 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# python3.7
|
2 |
+
"""Contains the Inception V3 model, which is used for inference ONLY.
|
3 |
+
|
4 |
+
This file is mostly borrowed from `torchvision/models/inception.py`.
|
5 |
+
|
6 |
+
Inception model is widely used to compute FID or IS metric for evaluating
|
7 |
+
generative models. However, the pre-trained models from torchvision is slightly
|
8 |
+
different from the TensorFlow version
|
9 |
+
|
10 |
+
http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
|
11 |
+
|
12 |
+
which is used by the official FID implementation
|
13 |
+
|
14 |
+
https://github.com/bioinf-jku/TTUR
|
15 |
+
|
16 |
+
In particular:
|
17 |
+
|
18 |
+
(1) The number of classes in TensorFlow model is 1008 instead of 1000.
|
19 |
+
(2) The avg_pool() layers in TensorFlow model does not include the padded zero.
|
20 |
+
(3) The last Inception E Block in TensorFlow model use max_pool() instead of
|
21 |
+
avg_pool().
|
22 |
+
|
23 |
+
Hence, to align the evaluation results with those from TensorFlow
|
24 |
+
implementation, we modified the inception model to support both versions. Please
|
25 |
+
use `align_tf` argument to control the version.
|
26 |
+
"""
|
27 |
+
|
28 |
+
import warnings
|
29 |
+
|
30 |
+
import torch
|
31 |
+
import torch.nn as nn
|
32 |
+
import torch.nn.functional as F
|
33 |
+
import torch.distributed as dist
|
34 |
+
|
35 |
+
from utils.misc import download_url
|
36 |
+
|
37 |
+
__all__ = ['InceptionModel']
|
38 |
+
|
39 |
+
# pylint: disable=line-too-long
|
40 |
+
|
41 |
+
_MODEL_URL_SHA256 = {
|
42 |
+
# This model is provided by `torchvision`, which is ported from TensorFlow.
|
43 |
+
'torchvision_official': (
|
44 |
+
'https://download.pytorch.org/models/inception_v3_google-1a9a5a14.pth',
|
45 |
+
'1a9a5a14f40645a370184bd54f4e8e631351e71399112b43ad0294a79da290c8' # hash sha256
|
46 |
+
),
|
47 |
+
|
48 |
+
# This model is provided by https://github.com/mseitzer/pytorch-fid
|
49 |
+
'tf_inception_v3': (
|
50 |
+
'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth',
|
51 |
+
'6726825d0af5f729cebd5821db510b11b1cfad8faad88a03f1befd49fb9129b2' # hash sha256
|
52 |
+
)
|
53 |
+
}
|
54 |
+
|
55 |
+
|
56 |
+
class InceptionModel(object):
|
57 |
+
"""Defines the Inception (V3) model.
|
58 |
+
|
59 |
+
This is a static class, which is used to avoid this model to be built
|
60 |
+
repeatedly. Consequently, this model is particularly used for inference,
|
61 |
+
like computing FID. If training is required, please use the model from
|
62 |
+
`torchvision.models` or implement by yourself.
|
63 |
+
|
64 |
+
NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
|
65 |
+
order and pixel range [-1, 1], and will also resize the images to shape
|
66 |
+
[299, 299] automatically. If your input is normalized by subtracting
|
67 |
+
(0.485, 0.456, 0.406) and dividing (0.229, 0.224, 0.225), please use
|
68 |
+
`transform_input` in the `forward()` function to un-normalize it.
|
69 |
+
"""
|
70 |
+
models = dict()
|
71 |
+
|
72 |
+
@staticmethod
|
73 |
+
def build_model(align_tf=True):
|
74 |
+
"""Builds the model and load pre-trained weights.
|
75 |
+
|
76 |
+
If `align_tf` is set as True, the model will predict 1008 classes, and
|
77 |
+
the pre-trained weight from `https://github.com/mseitzer/pytorch-fid`
|
78 |
+
will be loaded. Otherwise, the model will predict 1000 classes, and will
|
79 |
+
load the model from `torchvision`.
|
80 |
+
|
81 |
+
The built model supports following arguments when forwarding:
|
82 |
+
|
83 |
+
- transform_input: Whether to transform the input back to pixel range
|
84 |
+
(-1, 1). Please disable this argument if your input is already with
|
85 |
+
pixel range (-1, 1). (default: False)
|
86 |
+
- output_logits: Whether to output the categorical logits instead of
|
87 |
+
features. (default: False)
|
88 |
+
- remove_logits_bias: Whether to remove the bias when computing the
|
89 |
+
logits. The official implementation removes the bias by default.
|
90 |
+
Please refer to
|
91 |
+
`https://github.com/openai/improved-gan/blob/master/inception_score/model.py`.
|
92 |
+
(default: False)
|
93 |
+
- output_predictions: Whether to output the final predictions, i.e.,
|
94 |
+
`softmax(logits)`. (default: False)
|
95 |
+
"""
|
96 |
+
if align_tf:
|
97 |
+
num_classes = 1008
|
98 |
+
model_source = 'tf_inception_v3'
|
99 |
+
else:
|
100 |
+
num_classes = 1000
|
101 |
+
model_source = 'torchvision_official'
|
102 |
+
|
103 |
+
fingerprint = model_source
|
104 |
+
|
105 |
+
if fingerprint not in InceptionModel.models:
|
106 |
+
# Build model.
|
107 |
+
model = Inception3(num_classes=num_classes,
|
108 |
+
aux_logits=False,
|
109 |
+
init_weights=False,
|
110 |
+
align_tf=align_tf)
|
111 |
+
|
112 |
+
# Download pre-trained weights.
|
113 |
+
if dist.is_initialized() and dist.get_rank() != 0:
|
114 |
+
dist.barrier() # Download by chief.
|
115 |
+
|
116 |
+
url, sha256 = _MODEL_URL_SHA256[model_source]
|
117 |
+
filename = f'inception_model_{model_source}_{sha256}.pth'
|
118 |
+
model_path, hash_check = download_url(url,
|
119 |
+
filename=filename,
|
120 |
+
sha256=sha256)
|
121 |
+
state_dict = torch.load(model_path, map_location='cpu')
|
122 |
+
if hash_check is False:
|
123 |
+
warnings.warn(f'Hash check failed! The remote file from URL '
|
124 |
+
f'`{url}` may be changed, or the downloading is '
|
125 |
+
f'interrupted. The loaded inception model may '
|
126 |
+
f'have unexpected behavior.')
|
127 |
+
|
128 |
+
if dist.is_initialized() and dist.get_rank() == 0:
|
129 |
+
dist.barrier() # Wait for other replicas.
|
130 |
+
|
131 |
+
# Load weights.
|
132 |
+
model.load_state_dict(state_dict, strict=False)
|
133 |
+
del state_dict
|
134 |
+
|
135 |
+
# For inference only.
|
136 |
+
model.eval().requires_grad_(False).cuda()
|
137 |
+
InceptionModel.models[fingerprint] = model
|
138 |
+
|
139 |
+
return InceptionModel.models[fingerprint]
|
140 |
+
|
141 |
+
# pylint: disable=missing-function-docstring
|
142 |
+
# pylint: disable=missing-class-docstring
|
143 |
+
# pylint: disable=super-with-arguments
|
144 |
+
# pylint: disable=consider-merging-isinstance
|
145 |
+
# pylint: disable=import-outside-toplevel
|
146 |
+
# pylint: disable=no-else-return
|
147 |
+
|
148 |
+
class Inception3(nn.Module):
|
149 |
+
|
150 |
+
def __init__(self, num_classes=1000, aux_logits=True, inception_blocks=None,
|
151 |
+
init_weights=True, align_tf=True):
|
152 |
+
super(Inception3, self).__init__()
|
153 |
+
if inception_blocks is None:
|
154 |
+
inception_blocks = [
|
155 |
+
BasicConv2d, InceptionA, InceptionB, InceptionC,
|
156 |
+
InceptionD, InceptionE, InceptionAux
|
157 |
+
]
|
158 |
+
assert len(inception_blocks) == 7
|
159 |
+
conv_block = inception_blocks[0]
|
160 |
+
inception_a = inception_blocks[1]
|
161 |
+
inception_b = inception_blocks[2]
|
162 |
+
inception_c = inception_blocks[3]
|
163 |
+
inception_d = inception_blocks[4]
|
164 |
+
inception_e = inception_blocks[5]
|
165 |
+
inception_aux = inception_blocks[6]
|
166 |
+
|
167 |
+
self.aux_logits = aux_logits
|
168 |
+
self.align_tf = align_tf
|
169 |
+
self.Conv2d_1a_3x3 = conv_block(3, 32, kernel_size=3, stride=2)
|
170 |
+
self.Conv2d_2a_3x3 = conv_block(32, 32, kernel_size=3)
|
171 |
+
self.Conv2d_2b_3x3 = conv_block(32, 64, kernel_size=3, padding=1)
|
172 |
+
self.Conv2d_3b_1x1 = conv_block(64, 80, kernel_size=1)
|
173 |
+
self.Conv2d_4a_3x3 = conv_block(80, 192, kernel_size=3)
|
174 |
+
self.Mixed_5b = inception_a(192, pool_features=32, align_tf=self.align_tf)
|
175 |
+
self.Mixed_5c = inception_a(256, pool_features=64, align_tf=self.align_tf)
|
176 |
+
self.Mixed_5d = inception_a(288, pool_features=64, align_tf=self.align_tf)
|
177 |
+
self.Mixed_6a = inception_b(288)
|
178 |
+
self.Mixed_6b = inception_c(768, channels_7x7=128, align_tf=self.align_tf)
|
179 |
+
self.Mixed_6c = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
|
180 |
+
self.Mixed_6d = inception_c(768, channels_7x7=160, align_tf=self.align_tf)
|
181 |
+
self.Mixed_6e = inception_c(768, channels_7x7=192, align_tf=self.align_tf)
|
182 |
+
if aux_logits:
|
183 |
+
self.AuxLogits = inception_aux(768, num_classes)
|
184 |
+
self.Mixed_7a = inception_d(768)
|
185 |
+
self.Mixed_7b = inception_e(1280, align_tf=self.align_tf)
|
186 |
+
self.Mixed_7c = inception_e(2048, use_max_pool=self.align_tf)
|
187 |
+
self.fc = nn.Linear(2048, num_classes)
|
188 |
+
if init_weights:
|
189 |
+
for m in self.modules():
|
190 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
191 |
+
import scipy.stats as stats
|
192 |
+
stddev = m.stddev if hasattr(m, 'stddev') else 0.1
|
193 |
+
X = stats.truncnorm(-2, 2, scale=stddev)
|
194 |
+
values = torch.as_tensor(X.rvs(m.weight.numel()), dtype=m.weight.dtype)
|
195 |
+
values = values.view(m.weight.size())
|
196 |
+
with torch.no_grad():
|
197 |
+
m.weight.copy_(values)
|
198 |
+
elif isinstance(m, nn.BatchNorm2d):
|
199 |
+
nn.init.constant_(m.weight, 1)
|
200 |
+
nn.init.constant_(m.bias, 0)
|
201 |
+
|
202 |
+
@staticmethod
|
203 |
+
def _transform_input(x, transform_input=False):
|
204 |
+
if transform_input:
|
205 |
+
x_ch0 = torch.unsqueeze(x[:, 0], 1) * (0.229 / 0.5) + (0.485 - 0.5) / 0.5
|
206 |
+
x_ch1 = torch.unsqueeze(x[:, 1], 1) * (0.224 / 0.5) + (0.456 - 0.5) / 0.5
|
207 |
+
x_ch2 = torch.unsqueeze(x[:, 2], 1) * (0.225 / 0.5) + (0.406 - 0.5) / 0.5
|
208 |
+
x = torch.cat((x_ch0, x_ch1, x_ch2), 1)
|
209 |
+
return x
|
210 |
+
|
211 |
+
def _forward(self,
|
212 |
+
x,
|
213 |
+
output_logits=False,
|
214 |
+
remove_logits_bias=False,
|
215 |
+
output_predictions=False):
|
216 |
+
# Upsample if necessary.
|
217 |
+
if x.shape[2] != 299 or x.shape[3] != 299:
|
218 |
+
if self.align_tf:
|
219 |
+
theta = torch.eye(2, 3).to(x)
|
220 |
+
theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 299
|
221 |
+
theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 299
|
222 |
+
theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
|
223 |
+
grid = F.affine_grid(theta,
|
224 |
+
size=(x.shape[0], x.shape[1], 299, 299),
|
225 |
+
align_corners=False)
|
226 |
+
x = F.grid_sample(x, grid,
|
227 |
+
mode='bilinear',
|
228 |
+
padding_mode='border',
|
229 |
+
align_corners=False)
|
230 |
+
else:
|
231 |
+
x = F.interpolate(
|
232 |
+
x, size=(299, 299), mode='bilinear', align_corners=False)
|
233 |
+
if x.shape[1] == 1:
|
234 |
+
x = x.repeat((1, 3, 1, 1))
|
235 |
+
|
236 |
+
if self.align_tf:
|
237 |
+
x = (x * 127.5 + 127.5 - 128) / 128
|
238 |
+
|
239 |
+
# N x 3 x 299 x 299
|
240 |
+
x = self.Conv2d_1a_3x3(x)
|
241 |
+
# N x 32 x 149 x 149
|
242 |
+
x = self.Conv2d_2a_3x3(x)
|
243 |
+
# N x 32 x 147 x 147
|
244 |
+
x = self.Conv2d_2b_3x3(x)
|
245 |
+
# N x 64 x 147 x 147
|
246 |
+
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
247 |
+
# N x 64 x 73 x 73
|
248 |
+
x = self.Conv2d_3b_1x1(x)
|
249 |
+
# N x 80 x 73 x 73
|
250 |
+
x = self.Conv2d_4a_3x3(x)
|
251 |
+
# N x 192 x 71 x 71
|
252 |
+
x = F.max_pool2d(x, kernel_size=3, stride=2)
|
253 |
+
# N x 192 x 35 x 35
|
254 |
+
x = self.Mixed_5b(x)
|
255 |
+
# N x 256 x 35 x 35
|
256 |
+
x = self.Mixed_5c(x)
|
257 |
+
# N x 288 x 35 x 35
|
258 |
+
x = self.Mixed_5d(x)
|
259 |
+
# N x 288 x 35 x 35
|
260 |
+
x = self.Mixed_6a(x)
|
261 |
+
# N x 768 x 17 x 17
|
262 |
+
x = self.Mixed_6b(x)
|
263 |
+
# N x 768 x 17 x 17
|
264 |
+
x = self.Mixed_6c(x)
|
265 |
+
# N x 768 x 17 x 17
|
266 |
+
x = self.Mixed_6d(x)
|
267 |
+
# N x 768 x 17 x 17
|
268 |
+
x = self.Mixed_6e(x)
|
269 |
+
# N x 768 x 17 x 17
|
270 |
+
if self.training and self.aux_logits:
|
271 |
+
aux = self.AuxLogits(x)
|
272 |
+
else:
|
273 |
+
aux = None
|
274 |
+
# N x 768 x 17 x 17
|
275 |
+
x = self.Mixed_7a(x)
|
276 |
+
# N x 1280 x 8 x 8
|
277 |
+
x = self.Mixed_7b(x)
|
278 |
+
# N x 2048 x 8 x 8
|
279 |
+
x = self.Mixed_7c(x)
|
280 |
+
# N x 2048 x 8 x 8
|
281 |
+
# Adaptive average pooling
|
282 |
+
x = F.adaptive_avg_pool2d(x, (1, 1))
|
283 |
+
# N x 2048 x 1 x 1
|
284 |
+
x = F.dropout(x, training=self.training)
|
285 |
+
# N x 2048 x 1 x 1
|
286 |
+
x = torch.flatten(x, 1)
|
287 |
+
# N x 2048
|
288 |
+
if output_logits or output_predictions:
|
289 |
+
x = self.fc(x)
|
290 |
+
# N x 1000 (num_classes)
|
291 |
+
if remove_logits_bias:
|
292 |
+
x = x - self.fc.bias.view(1, -1)
|
293 |
+
if output_predictions:
|
294 |
+
x = F.softmax(x, dim=1)
|
295 |
+
return x, aux
|
296 |
+
|
297 |
+
def forward(self,
|
298 |
+
x,
|
299 |
+
transform_input=False,
|
300 |
+
output_logits=False,
|
301 |
+
remove_logits_bias=False,
|
302 |
+
output_predictions=False):
|
303 |
+
x = self._transform_input(x, transform_input)
|
304 |
+
x, aux = self._forward(
|
305 |
+
x, output_logits, remove_logits_bias, output_predictions)
|
306 |
+
if self.training and self.aux_logits:
|
307 |
+
return x, aux
|
308 |
+
else:
|
309 |
+
return x
|
310 |
+
|
311 |
+
|
312 |
+
class InceptionA(nn.Module):
|
313 |
+
|
314 |
+
def __init__(self, in_channels, pool_features, conv_block=None, align_tf=False):
|
315 |
+
+        super(InceptionA, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 64, kernel_size=1)
+
+        self.branch5x5_1 = conv_block(in_channels, 48, kernel_size=1)
+        self.branch5x5_2 = conv_block(48, 64, kernel_size=5, padding=2)
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, padding=1)
+
+        self.branch_pool = conv_block(in_channels, pool_features, kernel_size=1)
+        self.pool_include_padding = not align_tf
+
+    def _forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch5x5 = self.branch5x5_1(x)
+        branch5x5 = self.branch5x5_2(branch5x5)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=self.pool_include_padding)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch5x5, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x):
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionB(nn.Module):
+
+    def __init__(self, in_channels, conv_block=None):
+        super(InceptionB, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch3x3 = conv_block(in_channels, 384, kernel_size=3, stride=2)
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 64, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(64, 96, kernel_size=3, padding=1)
+        self.branch3x3dbl_3 = conv_block(96, 96, kernel_size=3, stride=2)
+
+    def _forward(self, x):
+        branch3x3 = self.branch3x3(x)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = self.branch3x3dbl_3(branch3x3dbl)
+
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+
+        outputs = [branch3x3, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x):
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionC(nn.Module):
+
+    def __init__(self, in_channels, channels_7x7, conv_block=None, align_tf=False):
+        super(InceptionC, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 192, kernel_size=1)
+
+        c7 = channels_7x7
+        self.branch7x7_1 = conv_block(in_channels, c7, kernel_size=1)
+        self.branch7x7_2 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7_3 = conv_block(c7, 192, kernel_size=(7, 1), padding=(3, 0))
+
+        self.branch7x7dbl_1 = conv_block(in_channels, c7, kernel_size=1)
+        self.branch7x7dbl_2 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7dbl_3 = conv_block(c7, c7, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7dbl_4 = conv_block(c7, c7, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7dbl_5 = conv_block(c7, 192, kernel_size=(1, 7), padding=(0, 3))
+
+        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+        self.pool_include_padding = not align_tf
+
+    def _forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch7x7 = self.branch7x7_1(x)
+        branch7x7 = self.branch7x7_2(branch7x7)
+        branch7x7 = self.branch7x7_3(branch7x7)
+
+        branch7x7dbl = self.branch7x7dbl_1(x)
+        branch7x7dbl = self.branch7x7dbl_2(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_3(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_4(branch7x7dbl)
+        branch7x7dbl = self.branch7x7dbl_5(branch7x7dbl)
+
+        branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                   count_include_pad=self.pool_include_padding)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch7x7, branch7x7dbl, branch_pool]
+        return outputs
+
+    def forward(self, x):
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionD(nn.Module):
+
+    def __init__(self, in_channels, conv_block=None):
+        super(InceptionD, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch3x3_1 = conv_block(in_channels, 192, kernel_size=1)
+        self.branch3x3_2 = conv_block(192, 320, kernel_size=3, stride=2)
+
+        self.branch7x7x3_1 = conv_block(in_channels, 192, kernel_size=1)
+        self.branch7x7x3_2 = conv_block(192, 192, kernel_size=(1, 7), padding=(0, 3))
+        self.branch7x7x3_3 = conv_block(192, 192, kernel_size=(7, 1), padding=(3, 0))
+        self.branch7x7x3_4 = conv_block(192, 192, kernel_size=3, stride=2)
+
+    def _forward(self, x):
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = self.branch3x3_2(branch3x3)
+
+        branch7x7x3 = self.branch7x7x3_1(x)
+        branch7x7x3 = self.branch7x7x3_2(branch7x7x3)
+        branch7x7x3 = self.branch7x7x3_3(branch7x7x3)
+        branch7x7x3 = self.branch7x7x3_4(branch7x7x3)
+
+        branch_pool = F.max_pool2d(x, kernel_size=3, stride=2)
+        outputs = [branch3x3, branch7x7x3, branch_pool]
+        return outputs
+
+    def forward(self, x):
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionE(nn.Module):
+
+    def __init__(self, in_channels, conv_block=None, align_tf=False, use_max_pool=False):
+        super(InceptionE, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.branch1x1 = conv_block(in_channels, 320, kernel_size=1)
+
+        self.branch3x3_1 = conv_block(in_channels, 384, kernel_size=1)
+        self.branch3x3_2a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+        self.branch3x3_2b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+        self.branch3x3dbl_1 = conv_block(in_channels, 448, kernel_size=1)
+        self.branch3x3dbl_2 = conv_block(448, 384, kernel_size=3, padding=1)
+        self.branch3x3dbl_3a = conv_block(384, 384, kernel_size=(1, 3), padding=(0, 1))
+        self.branch3x3dbl_3b = conv_block(384, 384, kernel_size=(3, 1), padding=(1, 0))
+
+        self.branch_pool = conv_block(in_channels, 192, kernel_size=1)
+        self.pool_include_padding = not align_tf
+        self.use_max_pool = use_max_pool
+
+    def _forward(self, x):
+        branch1x1 = self.branch1x1(x)
+
+        branch3x3 = self.branch3x3_1(x)
+        branch3x3 = [
+            self.branch3x3_2a(branch3x3),
+            self.branch3x3_2b(branch3x3),
+        ]
+        branch3x3 = torch.cat(branch3x3, 1)
+
+        branch3x3dbl = self.branch3x3dbl_1(x)
+        branch3x3dbl = self.branch3x3dbl_2(branch3x3dbl)
+        branch3x3dbl = [
+            self.branch3x3dbl_3a(branch3x3dbl),
+            self.branch3x3dbl_3b(branch3x3dbl),
+        ]
+        branch3x3dbl = torch.cat(branch3x3dbl, 1)
+
+        if self.use_max_pool:
+            branch_pool = F.max_pool2d(x, kernel_size=3, stride=1, padding=1)
+        else:
+            branch_pool = F.avg_pool2d(x, kernel_size=3, stride=1, padding=1,
+                                       count_include_pad=self.pool_include_padding)
+        branch_pool = self.branch_pool(branch_pool)
+
+        outputs = [branch1x1, branch3x3, branch3x3dbl, branch_pool]
+        return outputs
+
+    def forward(self, x):
+        outputs = self._forward(x)
+        return torch.cat(outputs, 1)
+
+
+class InceptionAux(nn.Module):
+
+    def __init__(self, in_channels, num_classes, conv_block=None):
+        super(InceptionAux, self).__init__()
+        if conv_block is None:
+            conv_block = BasicConv2d
+        self.conv0 = conv_block(in_channels, 128, kernel_size=1)
+        self.conv1 = conv_block(128, 768, kernel_size=5)
+        self.conv1.stddev = 0.01
+        self.fc = nn.Linear(768, num_classes)
+        self.fc.stddev = 0.001
+
+    def forward(self, x):
+        # N x 768 x 17 x 17
+        x = F.avg_pool2d(x, kernel_size=5, stride=3)
+        # N x 768 x 5 x 5
+        x = self.conv0(x)
+        # N x 128 x 5 x 5
+        x = self.conv1(x)
+        # N x 768 x 1 x 1
+        # Adaptive average pooling
+        x = F.adaptive_avg_pool2d(x, (1, 1))
+        # N x 768 x 1 x 1
+        x = torch.flatten(x, 1)
+        # N x 768
+        x = self.fc(x)
+        # N x 1000
+        return x
+
+
+class BasicConv2d(nn.Module):
+
+    def __init__(self, in_channels, out_channels, **kwargs):
+        super(BasicConv2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels, out_channels, bias=False, **kwargs)
+        self.bn = nn.BatchNorm2d(out_channels, eps=0.001)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        return F.relu(x, inplace=True)
+
+# pylint: enable=line-too-long
+# pylint: enable=missing-function-docstring
+# pylint: enable=missing-class-docstring
+# pylint: enable=super-with-arguments
+# pylint: enable=consider-merging-isinstance
+# pylint: enable=import-outside-toplevel
+# pylint: enable=no-else-return
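Note on the blocks above: each `Inception*` module runs its parallel branches and concatenates them along the channel axis, so the output width is simply the sum of the branch widths (for `InceptionA`: 64 + 64 + 96 + `pool_features`). Below is a minimal sketch of exercising two of these blocks in isolation; the input shapes are illustrative only, and it assumes the classes are importable from `models/inception_model.py` and that `InceptionA` takes `(in_channels, pool_features, ...)` as suggested by its attribute usage.

    import torch
    from models.inception_model import InceptionA, InceptionB

    x = torch.randn(2, 192, 35, 35)   # toy activation map (illustrative shape)
    block_a = InceptionA(192, 32)     # branches: 64 + 64 + 96 + 32 = 256 channels
    out_a = block_a(x)
    assert out_a.shape == (2, 256, 35, 35)

    block_b = InceptionB(256)         # stride-2 branches reduce 35x35 -> 17x17
    out_b = block_b(out_a)            # channels: 384 + 96 + 256 (max pool keeps input width)
    assert out_b.shape == (2, 736, 17, 17)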
models/perceptual_model.py
ADDED
@@ -0,0 +1,519 @@
+# python3.7
+"""Contains the VGG16 model, which is used for inference ONLY.
+
+VGG16 is commonly used for perceptual feature extraction. The model implemented
+in this file can be used for evaluation (like computing LPIPS, perceptual path
+length, etc.), OR be used in training for loss computation (like perceptual
+loss, etc.).
+
+The pre-trained model is officially shared by
+
+https://www.robots.ox.ac.uk/~vgg/research/very_deep/
+
+and ported by
+
+https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt
+
+Compared to the official VGG16 model, this ported model also supports
+evaluating LPIPS, which is introduced in
+
+https://github.com/richzhang/PerceptualSimilarity
+"""
+
+import warnings
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from utils.misc import download_url
+
+__all__ = ['PerceptualModel']
+
+# pylint: disable=line-too-long
+_MODEL_URL_SHA256 = {
+    # This model is provided by `torchvision`, which is ported from TensorFlow.
+    'torchvision_official': (
+        'https://download.pytorch.org/models/vgg16-397923af.pth',
+        '397923af8e79cdbb6a7127f12361acd7a2f83e06b05044ddf496e83de57a5bf0'  # hash sha256
+    ),
+
+    # This model is provided by https://github.com/NVlabs/stylegan2-ada-pytorch
+    'vgg_perceptual_lpips': (
+        'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt',
+        'b437eb095feaeb0b83eb3fa11200ebca4548ee39a07fb944a417ddc516cc07c3'  # hash sha256
+    )
+}
+# pylint: enable=line-too-long
+
+
+class PerceptualModel(object):
+    """Defines the perceptual model, which is based on the VGG16 structure.
+
+    This is a static class, which is used to avoid building this model
+    repeatedly. Consequently, this model is particularly used for inference,
+    like computing LPIPS, or for loss computation, like perceptual loss. If
+    training is required, please use the model from `torchvision.models` or
+    implement one yourself.
+
+    NOTE: The pre-trained model assumes the inputs to be with `RGB` channel
+    order and pixel range [-1, 1], and will NOT resize the input automatically
+    if only perceptual feature is needed.
+    """
+    models = dict()
+
+    @staticmethod
+    def build_model(use_torchvision=False, no_top=True, enable_lpips=True):
+        """Builds the model and loads pre-trained weights.
+
+        1. If `use_torchvision` is set as True, the model released by
+           `torchvision` will be loaded, otherwise, the model released by
+           https://www.robots.ox.ac.uk/~vgg/research/very_deep/ will be used.
+           (default: False)
+
+        2. To save computing resources, there is an option to only load the
+           backbone (i.e., without the last three fully-connected layers). This
+           is commonly used for perceptual loss or LPIPS loss computation.
+           Please use argument `no_top` to control this. (default: True)
+
+        3. For LPIPS loss computation, some additional weights (which are used
+           for balancing the features from different resolutions) are employed
+           on top of the original VGG16 backbone. Details can be found at
+           https://github.com/richzhang/PerceptualSimilarity. Please use
+           `enable_lpips` to enable this feature. (default: True)
+
+        The built model supports the following arguments when forwarding:
+
+        - resize_input: Whether to resize the input image to size [224, 224]
+          before forwarding. For feature-based computation (i.e., only
+          convolutional layers are used), image resizing is not essential.
+          (default: False)
+        - return_tensor: This field resolves the model behavior. Following
+          options are supported:
+            `feature1`: Before the first max pooling layer.
+            `pool1`: After the first max pooling layer.
+            `feature2`: Before the second max pooling layer.
+            `pool2`: After the second max pooling layer.
+            `feature3`: Before the third max pooling layer.
+            `pool3`: After the third max pooling layer.
+            `feature4`: Before the fourth max pooling layer.
+            `pool4`: After the fourth max pooling layer.
+            `feature5`: Before the fifth max pooling layer.
+            `pool5`: After the fifth max pooling layer.
+            `flatten`: The flattened feature, after `adaptive_avgpool`.
+            `feature`: The 4096d feature for logits computation. (default)
+            `logits`: The 1000d categorical logits.
+            `prediction`: The 1000d predicted probability.
+            `lpips`: The LPIPS score between two input images.
+        """
+        if use_torchvision:
+            model_source = 'torchvision_official'
+            align_tf_resize = False
+            is_torch_script = False
+        else:
+            model_source = 'vgg_perceptual_lpips'
+            align_tf_resize = True
+            is_torch_script = True
+
+        if enable_lpips and model_source != 'vgg_perceptual_lpips':
+            warnings.warn('The pre-trained model officially released by '
+                          '`torchvision` does not support LPIPS computation! '
+                          'Equal weights will be used for each resolution.')
+
+        fingerprint = (model_source, no_top, enable_lpips)
+
+        if fingerprint not in PerceptualModel.models:
+            # Build model.
+            model = VGG16(align_tf_resize=align_tf_resize,
+                          no_top=no_top,
+                          enable_lpips=enable_lpips)
+
+            # Download pre-trained weights.
+            if dist.is_initialized() and dist.get_rank() != 0:
+                dist.barrier()  # Download by chief.
+
+            url, sha256 = _MODEL_URL_SHA256[model_source]
+            filename = f'perceptual_model_{model_source}_{sha256}.pth'
+            model_path, hash_check = download_url(url,
+                                                  filename=filename,
+                                                  sha256=sha256)
+            if is_torch_script:
+                src_state_dict = torch.jit.load(model_path, map_location='cpu')
+            else:
+                src_state_dict = torch.load(model_path, map_location='cpu')
+            if hash_check is False:
+                warnings.warn(f'Hash check failed! The remote file from URL '
+                              f'`{url}` may be changed, or the downloading is '
+                              f'interrupted. The loaded perceptual model may '
+                              f'have unexpected behavior.')
+
+            if dist.is_initialized() and dist.get_rank() == 0:
+                dist.barrier()  # Wait for other replicas.
+
+            # Load weights.
+            dst_state_dict = _convert_weights(src_state_dict, model_source)
+            model.load_state_dict(dst_state_dict, strict=False)
+            del src_state_dict, dst_state_dict
+
+            # For inference only.
+            model.eval().requires_grad_(False).cuda()
+            PerceptualModel.models[fingerprint] = model
+
+        return PerceptualModel.models[fingerprint]
+
+
+def _convert_weights(src_state_dict, model_source):
+    if model_source not in _MODEL_URL_SHA256:
+        raise ValueError(f'Invalid model source `{model_source}`!\n'
+                         f'Sources allowed: {list(_MODEL_URL_SHA256.keys())}.')
+    if model_source == 'torchvision_official':
+        dst_to_src_var_mapping = {
+            'conv11.weight': 'features.0.weight',
+            'conv11.bias': 'features.0.bias',
+            'conv12.weight': 'features.2.weight',
+            'conv12.bias': 'features.2.bias',
+            'conv21.weight': 'features.5.weight',
+            'conv21.bias': 'features.5.bias',
+            'conv22.weight': 'features.7.weight',
+            'conv22.bias': 'features.7.bias',
+            'conv31.weight': 'features.10.weight',
+            'conv31.bias': 'features.10.bias',
+            'conv32.weight': 'features.12.weight',
+            'conv32.bias': 'features.12.bias',
+            'conv33.weight': 'features.14.weight',
+            'conv33.bias': 'features.14.bias',
+            'conv41.weight': 'features.17.weight',
+            'conv41.bias': 'features.17.bias',
+            'conv42.weight': 'features.19.weight',
+            'conv42.bias': 'features.19.bias',
+            'conv43.weight': 'features.21.weight',
+            'conv43.bias': 'features.21.bias',
+            'conv51.weight': 'features.24.weight',
+            'conv51.bias': 'features.24.bias',
+            'conv52.weight': 'features.26.weight',
+            'conv52.bias': 'features.26.bias',
+            'conv53.weight': 'features.28.weight',
+            'conv53.bias': 'features.28.bias',
+            'fc1.weight': 'classifier.0.weight',
+            'fc1.bias': 'classifier.0.bias',
+            'fc2.weight': 'classifier.3.weight',
+            'fc2.bias': 'classifier.3.bias',
+            'fc3.weight': 'classifier.6.weight',
+            'fc3.bias': 'classifier.6.bias',
+        }
+    elif model_source == 'vgg_perceptual_lpips':
+        src_state_dict = src_state_dict.state_dict()
+        dst_to_src_var_mapping = {
+            'conv11.weight': 'layers.conv1.weight',
+            'conv11.bias': 'layers.conv1.bias',
+            'conv12.weight': 'layers.conv2.weight',
+            'conv12.bias': 'layers.conv2.bias',
+            'conv21.weight': 'layers.conv3.weight',
+            'conv21.bias': 'layers.conv3.bias',
+            'conv22.weight': 'layers.conv4.weight',
+            'conv22.bias': 'layers.conv4.bias',
+            'conv31.weight': 'layers.conv5.weight',
+            'conv31.bias': 'layers.conv5.bias',
+            'conv32.weight': 'layers.conv6.weight',
+            'conv32.bias': 'layers.conv6.bias',
+            'conv33.weight': 'layers.conv7.weight',
+            'conv33.bias': 'layers.conv7.bias',
+            'conv41.weight': 'layers.conv8.weight',
+            'conv41.bias': 'layers.conv8.bias',
+            'conv42.weight': 'layers.conv9.weight',
+            'conv42.bias': 'layers.conv9.bias',
+            'conv43.weight': 'layers.conv10.weight',
+            'conv43.bias': 'layers.conv10.bias',
+            'conv51.weight': 'layers.conv11.weight',
+            'conv51.bias': 'layers.conv11.bias',
+            'conv52.weight': 'layers.conv12.weight',
+            'conv52.bias': 'layers.conv12.bias',
+            'conv53.weight': 'layers.conv13.weight',
+            'conv53.bias': 'layers.conv13.bias',
+            'fc1.weight': 'layers.fc1.weight',
+            'fc1.bias': 'layers.fc1.bias',
+            'fc2.weight': 'layers.fc2.weight',
+            'fc2.bias': 'layers.fc2.bias',
+            'fc3.weight': 'layers.fc3.weight',
+            'fc3.bias': 'layers.fc3.bias',
+            'lpips.0.weight': 'lpips0',
+            'lpips.1.weight': 'lpips1',
+            'lpips.2.weight': 'lpips2',
+            'lpips.3.weight': 'lpips3',
+            'lpips.4.weight': 'lpips4',
+        }
+    else:
+        raise NotImplementedError(f'Not implemented model source '
+                                  f'`{model_source}`!')
+
+    dst_state_dict = {}
+    for dst_name, src_name in dst_to_src_var_mapping.items():
+        if dst_name.startswith('lpips'):
+            dst_state_dict[dst_name] = src_state_dict[src_name].unsqueeze(0)
+        else:
+            dst_state_dict[dst_name] = src_state_dict[src_name].clone()
+    return dst_state_dict
+
+
+_IMG_MEAN = (0.485, 0.456, 0.406)
+_IMG_STD = (0.229, 0.224, 0.225)
+_ALLOWED_RETURN = [
+    'feature1', 'pool1', 'feature2', 'pool2', 'feature3', 'pool3', 'feature4',
+    'pool4', 'feature5', 'pool5', 'flatten', 'feature', 'logits', 'prediction',
+    'lpips'
+]
+
+# pylint: disable=missing-function-docstring
+
+class VGG16(nn.Module):
+    """Defines the VGG16 structure.
+
+    This model takes `RGB` images with data format `NCHW` as the raw inputs.
+    The pixel range is assumed to be [-1, 1].
+    """
+
+    def __init__(self, align_tf_resize=False, no_top=True, enable_lpips=True):
+        """Defines the network structure."""
+        super().__init__()
+
+        self.align_tf_resize = align_tf_resize
+        self.no_top = no_top
+        self.enable_lpips = enable_lpips
+
+        self.conv11 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
+        self.relu11 = nn.ReLU(inplace=True)
+        self.conv12 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
+        self.relu12 = nn.ReLU(inplace=True)
+        # output `feature1`, with shape [N, 64, 224, 224]
+
+        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
+        # output `pool1`, with shape [N, 64, 112, 112]
+
+        self.conv21 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
+        self.relu21 = nn.ReLU(inplace=True)
+        self.conv22 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
+        self.relu22 = nn.ReLU(inplace=True)
+        # output `feature2`, with shape [N, 128, 112, 112]
+
+        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
+        # output `pool2`, with shape [N, 128, 56, 56]
+
+        self.conv31 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
+        self.relu31 = nn.ReLU(inplace=True)
+        self.conv32 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+        self.relu32 = nn.ReLU(inplace=True)
+        self.conv33 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
+        self.relu33 = nn.ReLU(inplace=True)
+        # output `feature3`, with shape [N, 256, 56, 56]
+
+        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
+        # output `pool3`, with shape [N, 256, 28, 28]
+
+        self.conv41 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
+        self.relu41 = nn.ReLU(inplace=True)
+        self.conv42 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+        self.relu42 = nn.ReLU(inplace=True)
+        self.conv43 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+        self.relu43 = nn.ReLU(inplace=True)
+        # output `feature4`, with shape [N, 512, 28, 28]
+
+        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)
+        # output `pool4`, with shape [N, 512, 14, 14]
+
+        self.conv51 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+        self.relu51 = nn.ReLU(inplace=True)
+        self.conv52 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+        self.relu52 = nn.ReLU(inplace=True)
+        self.conv53 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
+        self.relu53 = nn.ReLU(inplace=True)
+        # output `feature5`, with shape [N, 512, 14, 14]
+
+        self.pool5 = nn.MaxPool2d(kernel_size=2, stride=2)
+        # output `pool5`, with shape [N, 512, 7, 7]
+
+        if self.enable_lpips:
+            self.lpips = nn.ModuleList()
+            for idx, ch in enumerate([64, 128, 256, 512, 512]):
+                self.lpips.append(nn.Conv2d(ch, 1, kernel_size=1, bias=False))
+                self.lpips[idx].weight.data.copy_(torch.ones(1, ch, 1, 1))
+
+        if not self.no_top:
+            self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
+            self.flatten = nn.Flatten(start_dim=1, end_dim=-1)
+            # output `flatten`, with shape [N, 25088]
+
+            self.fc1 = nn.Linear(512 * 7 * 7, 4096)
+            self.fc1_relu = nn.ReLU(inplace=True)
+            self.fc1_dropout = nn.Dropout(0.5, inplace=False)
+            self.fc2 = nn.Linear(4096, 4096)
+            self.fc2_relu = nn.ReLU(inplace=True)
+            self.fc2_dropout = nn.Dropout(0.5, inplace=False)
+            # output `feature`, with shape [N, 4096]
+
+            self.fc3 = nn.Linear(4096, 1000)
+            # output `logits`, with shape [N, 1000]
+
+            self.out = nn.Softmax(dim=1)
+            # output `softmax`, with shape [N, 1000]
+
+        img_mean = np.array(_IMG_MEAN).reshape((1, 3, 1, 1)).astype(np.float32)
+        img_std = np.array(_IMG_STD).reshape((1, 3, 1, 1)).astype(np.float32)
+        self.register_buffer('img_mean', torch.from_numpy(img_mean))
+        self.register_buffer('img_std', torch.from_numpy(img_std))
+
+    def forward(self,
+                x,
+                y=None,
+                *,
+                resize_input=False,
+                return_tensor='feature'):
+        return_tensor = return_tensor.lower()
+        if return_tensor not in _ALLOWED_RETURN:
+            raise ValueError(f'Invalid output tensor name `{return_tensor}` '
+                             f'for perceptual model (VGG16)!\n'
+                             f'Names allowed: {_ALLOWED_RETURN}.')
+
+        if return_tensor == 'lpips' and y is None:
+            raise ValueError('Two images are required for LPIPS computation, '
+                             'but only one is received!')
+
+        if return_tensor == 'lpips':
+            assert x.shape == y.shape
+            x = torch.cat([x, y], dim=0)
+            features = []
+
+        if resize_input:
+            if self.align_tf_resize:
+                theta = torch.eye(2, 3).to(x)
+                theta[0, 2] += theta[0, 0] / x.shape[3] - theta[0, 0] / 224
+                theta[1, 2] += theta[1, 1] / x.shape[2] - theta[1, 1] / 224
+                theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)
+                grid = F.affine_grid(theta,
+                                     size=(x.shape[0], x.shape[1], 224, 224),
+                                     align_corners=False)
+                x = F.grid_sample(x, grid,
+                                  mode='bilinear',
+                                  padding_mode='border',
+                                  align_corners=False)
+            else:
+                x = F.interpolate(x,
+                                  size=(224, 224),
+                                  mode='bilinear',
+                                  align_corners=False)
+        if x.shape[1] == 1:
+            x = x.repeat((1, 3, 1, 1))
+
+        x = (x + 1) / 2
+        x = (x - self.img_mean) / self.img_std
+
+        x = self.conv11(x)
+        x = self.relu11(x)
+        x = self.conv12(x)
+        x = self.relu12(x)
+        if return_tensor == 'feature1':
+            return x
+        if return_tensor == 'lpips':
+            features.append(x)
+
+        x = self.pool1(x)
+        if return_tensor == 'pool1':
+            return x
+
+        x = self.conv21(x)
+        x = self.relu21(x)
+        x = self.conv22(x)
+        x = self.relu22(x)
+        if return_tensor == 'feature2':
+            return x
+        if return_tensor == 'lpips':
+            features.append(x)
+
+        x = self.pool2(x)
+        if return_tensor == 'pool2':
+            return x
+
+        x = self.conv31(x)
+        x = self.relu31(x)
+        x = self.conv32(x)
+        x = self.relu32(x)
+        x = self.conv33(x)
+        x = self.relu33(x)
+        if return_tensor == 'feature3':
+            return x
+        if return_tensor == 'lpips':
+            features.append(x)
+
+        x = self.pool3(x)
+        if return_tensor == 'pool3':
+            return x
+
+        x = self.conv41(x)
+        x = self.relu41(x)
+        x = self.conv42(x)
+        x = self.relu42(x)
+        x = self.conv43(x)
+        x = self.relu43(x)
+        if return_tensor == 'feature4':
+            return x
+        if return_tensor == 'lpips':
+            features.append(x)
+
+        x = self.pool4(x)
+        if return_tensor == 'pool4':
+            return x
+
+        x = self.conv51(x)
+        x = self.relu51(x)
+        x = self.conv52(x)
+        x = self.relu52(x)
+        x = self.conv53(x)
+        x = self.relu53(x)
+        if return_tensor == 'feature5':
+            return x
+        if return_tensor == 'lpips':
+            features.append(x)
+
+        x = self.pool5(x)
+        if return_tensor == 'pool5':
+            return x
+
+        if return_tensor == 'lpips':
+            score = 0
+            assert len(features) == 5
+            for idx in range(5):
+                feature = features[idx]
+                norm = feature.norm(dim=1, keepdim=True)
+                feature = feature / (norm + 1e-10)
+                feature_x, feature_y = feature.chunk(2, dim=0)
+                diff = (feature_x - feature_y).square()
+                score += self.lpips[idx](diff).mean(dim=(2, 3), keepdim=False)
+            return score.sum(dim=1, keepdim=False)
+
+        x = self.avgpool(x)
+        x = self.flatten(x)
+        if return_tensor == 'flatten':
+            return x
+
+        x = self.fc1(x)
+        x = self.fc1_relu(x)
+        x = self.fc1_dropout(x)
+        x = self.fc2(x)
+        x = self.fc2_relu(x)
+        x = self.fc2_dropout(x)
+        if return_tensor == 'feature':
+            return x
+
+        x = self.fc3(x)
+        if return_tensor == 'logits':
+            return x
+
+        x = self.out(x)
+        if return_tensor == 'prediction':
+            return x
+
+        raise NotImplementedError(f'Output tensor name `{return_tensor}` is '
+                                  f'not implemented!')
+
+# pylint: enable=missing-function-docstring
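Usage note for the file above: `PerceptualModel.build_model()` builds (and caches) a VGG16 in eval mode on GPU, downloading the ported weights on first call, and the returned module yields either intermediate features or an LPIPS distance depending on `return_tensor`. Below is a minimal sketch under those assumptions; the tensors are random placeholders for illustration, and it requires a CUDA device plus network access for the one-time weight download.

    import torch
    from models.perceptual_model import PerceptualModel

    # Build or fetch the cached LPIPS-enabled VGG16 (weights downloaded on first use).
    model = PerceptualModel.build_model(use_torchvision=False,
                                        no_top=True,
                                        enable_lpips=True)

    # Two batches of RGB images, NCHW layout, pixel range [-1, 1].
    img_a = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1
    img_b = torch.rand(4, 3, 256, 256, device='cuda') * 2 - 1

    with torch.no_grad():
        # Per-sample LPIPS distance, shape [4].
        lpips = model(img_a, img_b, resize_input=True, return_tensor='lpips')
        # Intermediate conv features, e.g. for a perceptual loss (no resize needed).
        feat = model(img_a, return_tensor='feature3')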