Spanicin committed
Commit d3c3894
Parent: 4d7bc0c

Update videoretalking/models/ENet.py
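The only textual difference between the two versions below is line 5: the import of the shared building blocks switches from a repo-root-relative path to a package-qualified one. A minimal sketch of the effect, assuming the videoretalking directory is importable as a top-level package (e.g. when the repo is vendored into a larger project):

    # Before: resolves only when videoretalking/ itself is on sys.path
    from models.base_blocks import ResBlock, StyleConv, ToRGB

    # After: also resolves from outside the directory, as long as the
    # parent of videoretalking/ is on sys.path
    from videoretalking.models.base_blocks import ResBlock, StyleConv, ToRGB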

Files changed (1):
  videoretalking/models/ENet.py  +138 -138
videoretalking/models/ENet.py CHANGED
@@ -1,139 +1,139 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from models.base_blocks import ResBlock, StyleConv, ToRGB
-
-
-class ENet(nn.Module):
-    def __init__(
-            self,
-            num_style_feat=512,
-            lnet=None,
-            concat=False
-    ):
-        super(ENet, self).__init__()
-
-        self.low_res = lnet
-        for param in self.low_res.parameters():
-            param.requires_grad = False
-
-        channel_multiplier, narrow = 2, 1
-        channels = {
-            '4': int(512 * narrow),
-            '8': int(512 * narrow),
-            '16': int(512 * narrow),
-            '32': int(512 * narrow),
-            '64': int(256 * channel_multiplier * narrow),
-            '128': int(128 * channel_multiplier * narrow),
-            '256': int(64 * channel_multiplier * narrow),
-            '512': int(32 * channel_multiplier * narrow),
-            '1024': int(16 * channel_multiplier * narrow)
-        }
-
-        self.log_size = 8
-        first_out_size = 128
-        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)  # 256 -> 128
-
-        # downsample
-        in_channels = channels[f'{first_out_size}']
-        self.conv_body_down = nn.ModuleList()
-        for i in range(8, 2, -1):
-            out_channels = channels[f'{2**(i - 1)}']
-            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
-            in_channels = out_channels
-
-        self.num_style_feat = num_style_feat
-        linear_out_channel = num_style_feat
-        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
-        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
-
-        self.style_convs = nn.ModuleList()
-        self.to_rgbs = nn.ModuleList()
-        self.noises = nn.Module()
-
-        self.concat = concat
-        if concat:
-            in_channels = 3 + 32  # channels['64']
-        else:
-            in_channels = 3
-
-        for i in range(7, 9):  # 128, 256
-            out_channels = channels[f'{2**i}']
-            self.style_convs.append(
-                StyleConv(
-                    in_channels,
-                    out_channels,
-                    kernel_size=3,
-                    num_style_feat=num_style_feat,
-                    demodulate=True,
-                    sample_mode='upsample'))
-            self.style_convs.append(
-                StyleConv(
-                    out_channels,
-                    out_channels,
-                    kernel_size=3,
-                    num_style_feat=num_style_feat,
-                    demodulate=True,
-                    sample_mode=None))
-            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
-            in_channels = out_channels
-
-    def forward(self, audio_sequences, face_sequences, gt_sequences):
-        B = audio_sequences.size(0)
-        input_dim_size = len(face_sequences.size())
-        inp, ref = torch.split(face_sequences, 3, dim=1)
-
-        if input_dim_size > 4:
-            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
-            inp = torch.cat([inp[:, :, i] for i in range(inp.size(2))], dim=0)
-            ref = torch.cat([ref[:, :, i] for i in range(ref.size(2))], dim=0)
-            gt_sequences = torch.cat([gt_sequences[:, :, i] for i in range(gt_sequences.size(2))], dim=0)
-
-        # get the global style
-        feat = F.leaky_relu_(self.conv_body_first(F.interpolate(ref, size=(256, 256), mode='bilinear')), negative_slope=0.2)
-        for i in range(self.log_size - 2):
-            feat = self.conv_body_down[i](feat)
-        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
-
-        # style code
-        style_code = self.final_linear(feat.reshape(feat.size(0), -1))
-        style_code = style_code.reshape(style_code.size(0), -1, self.num_style_feat)
-
-        LNet_input = torch.cat([inp, gt_sequences], dim=1)
-        LNet_input = F.interpolate(LNet_input, size=(96, 96), mode='bilinear')
-
-        if self.concat:
-            low_res_img, low_res_feat = self.low_res(audio_sequences, LNet_input)
-            low_res_img.detach()
-            low_res_feat.detach()
-            out = torch.cat([low_res_img, low_res_feat], dim=1)
-
-        else:
-            low_res_img = self.low_res(audio_sequences, LNet_input)
-            low_res_img.detach()
-            # 96 x 96
-            out = low_res_img
-
-        p2d = (2, 2, 2, 2)
-        out = F.pad(out, p2d, "reflect", 0)
-        skip = out
-
-        for conv1, conv2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], self.to_rgbs):
-            out = conv1(out, style_code)  # 96, 192, 384
-            out = conv2(out, style_code)
-            skip = to_rgb(out, style_code, skip)
-        _outputs = skip
-
-        # remove padding
-        _outputs = _outputs[:, :, 8:-8, 8:-8]
-
-        if input_dim_size > 4:
-            _outputs = torch.split(_outputs, B, dim=0)
-            outputs = torch.stack(_outputs, dim=2)
-            low_res_img = F.interpolate(low_res_img, outputs.size()[3:])
-            low_res_img = torch.split(low_res_img, B, dim=0)
-            low_res_img = torch.stack(low_res_img, dim=2)
-        else:
-            outputs = _outputs
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from videoretalking.models.base_blocks import ResBlock, StyleConv, ToRGB
+
+
+class ENet(nn.Module):
+    def __init__(
+            self,
+            num_style_feat=512,
+            lnet=None,
+            concat=False
+    ):
+        super(ENet, self).__init__()
+
+        self.low_res = lnet
+        for param in self.low_res.parameters():
+            param.requires_grad = False
+
+        channel_multiplier, narrow = 2, 1
+        channels = {
+            '4': int(512 * narrow),
+            '8': int(512 * narrow),
+            '16': int(512 * narrow),
+            '32': int(512 * narrow),
+            '64': int(256 * channel_multiplier * narrow),
+            '128': int(128 * channel_multiplier * narrow),
+            '256': int(64 * channel_multiplier * narrow),
+            '512': int(32 * channel_multiplier * narrow),
+            '1024': int(16 * channel_multiplier * narrow)
+        }
+
+        self.log_size = 8
+        first_out_size = 128
+        self.conv_body_first = nn.Conv2d(3, channels[f'{first_out_size}'], 1)  # 256 -> 128
+
+        # downsample
+        in_channels = channels[f'{first_out_size}']
+        self.conv_body_down = nn.ModuleList()
+        for i in range(8, 2, -1):
+            out_channels = channels[f'{2**(i - 1)}']
+            self.conv_body_down.append(ResBlock(in_channels, out_channels, mode='down'))
+            in_channels = out_channels
+
+        self.num_style_feat = num_style_feat
+        linear_out_channel = num_style_feat
+        self.final_linear = nn.Linear(channels['4'] * 4 * 4, linear_out_channel)
+        self.final_conv = nn.Conv2d(in_channels, channels['4'], 3, 1, 1)
+
+        self.style_convs = nn.ModuleList()
+        self.to_rgbs = nn.ModuleList()
+        self.noises = nn.Module()
+
+        self.concat = concat
+        if concat:
+            in_channels = 3 + 32  # channels['64']
+        else:
+            in_channels = 3
+
+        for i in range(7, 9):  # 128, 256
+            out_channels = channels[f'{2**i}']
+            self.style_convs.append(
+                StyleConv(
+                    in_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode='upsample'))
+            self.style_convs.append(
+                StyleConv(
+                    out_channels,
+                    out_channels,
+                    kernel_size=3,
+                    num_style_feat=num_style_feat,
+                    demodulate=True,
+                    sample_mode=None))
+            self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
+            in_channels = out_channels
+
+    def forward(self, audio_sequences, face_sequences, gt_sequences):
+        B = audio_sequences.size(0)
+        input_dim_size = len(face_sequences.size())
+        inp, ref = torch.split(face_sequences, 3, dim=1)
+
+        if input_dim_size > 4:
+            audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+            inp = torch.cat([inp[:, :, i] for i in range(inp.size(2))], dim=0)
+            ref = torch.cat([ref[:, :, i] for i in range(ref.size(2))], dim=0)
+            gt_sequences = torch.cat([gt_sequences[:, :, i] for i in range(gt_sequences.size(2))], dim=0)
+
+        # get the global style
+        feat = F.leaky_relu_(self.conv_body_first(F.interpolate(ref, size=(256, 256), mode='bilinear')), negative_slope=0.2)
+        for i in range(self.log_size - 2):
+            feat = self.conv_body_down[i](feat)
+        feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
+
+        # style code
+        style_code = self.final_linear(feat.reshape(feat.size(0), -1))
+        style_code = style_code.reshape(style_code.size(0), -1, self.num_style_feat)
+
+        LNet_input = torch.cat([inp, gt_sequences], dim=1)
+        LNet_input = F.interpolate(LNet_input, size=(96, 96), mode='bilinear')
+
+        if self.concat:
+            low_res_img, low_res_feat = self.low_res(audio_sequences, LNet_input)
+            low_res_img.detach()
+            low_res_feat.detach()
+            out = torch.cat([low_res_img, low_res_feat], dim=1)
+
+        else:
+            low_res_img = self.low_res(audio_sequences, LNet_input)
+            low_res_img.detach()
+            # 96 x 96
+            out = low_res_img
+
+        p2d = (2, 2, 2, 2)
+        out = F.pad(out, p2d, "reflect", 0)
+        skip = out
+
+        for conv1, conv2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], self.to_rgbs):
+            out = conv1(out, style_code)  # 96, 192, 384
+            out = conv2(out, style_code)
+            skip = to_rgb(out, style_code, skip)
+        _outputs = skip
+
+        # remove padding
+        _outputs = _outputs[:, :, 8:-8, 8:-8]
+
+        if input_dim_size > 4:
+            _outputs = torch.split(_outputs, B, dim=0)
+            outputs = torch.stack(_outputs, dim=2)
+            low_res_img = F.interpolate(low_res_img, outputs.size()[3:])
+            low_res_img = torch.split(low_res_img, B, dim=0)
+            low_res_img = torch.stack(low_res_img, dim=2)
+        else:
+            outputs = _outputs
         return outputs, low_res_img
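A caveat that applies to both sides of the diff and is untouched by this commit: PyTorch's Tensor.detach() is out-of-place and returns a new tensor, so the bare low_res_img.detach() and low_res_feat.detach() calls in forward() are no-ops. LNet's parameters are already frozen in __init__, so they receive no gradient updates either way, but without reassignment backprop will still traverse LNet's graph. A minimal sketch of the presumed intent, not part of this commit:

    if self.concat:
        low_res_img, low_res_feat = self.low_res(audio_sequences, LNet_input)
        # detach() returns a new tensor; reassign to actually cut the graph
        low_res_img = low_res_img.detach()
        low_res_feat = low_res_feat.detach()
        out = torch.cat([low_res_img, low_res_feat], dim=1)
    else:
        low_res_img = self.low_res(audio_sequences, LNet_input).detach()
        out = low_res_img  # 96 x 96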