Spaces:
Sleeping
Sleeping
File size: 1,596 Bytes
b762e56 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ImageEncoder(nn.Module):
    """Convolutional encoder mapping a square image to a flat feature vector.

    Architecture: a stride-1 7x7 "stem" convolution, then
    ``n_downsampling = floor(log2(img_size))`` stride-2 convolutions. Every
    convolution is followed by ``norm_layer`` and an in-place ReLU. The channel
    width starts at ``ngf`` and doubles at each downsampling stage, so the last
    stage has ``ngf * 2 ** n_downsampling`` channels. The final feature map is
    flattened.

    For a power-of-two ``img_size`` the final feature map is 1x1, and the
    flattened feature has ``ngf * img_size`` elements. Other sizes are also
    supported. Each stride-2 stage maps a side of length ``s`` to
    ``ceil(s / 2)``, and the norm layers are sized to match.

    Args:
        img_size: Side length (height == width) of the input images.
        input_nc: Number of input image channels.
        ngf: Number of channels produced by the stem convolution.
        norm_layer: Normalization factory. It is called with a
            ``[channels, height, width]`` list, which matches the
            ``normalized_shape`` argument of ``nn.LayerNorm``.

    Raises:
        ValueError: If ``img_size`` is smaller than 1.
    """

    def __init__(self, img_size, input_nc, ngf=16, norm_layer=nn.LayerNorm):
        super().__init__()
        if img_size < 1:
            raise ValueError(f"img_size must be >= 1, got {img_size!r}")

        # log2 is exact for powers of two; log(x, 2) = ln(x) / ln(2) may not be.
        n_downsampling = int(math.log2(img_size))
        # 5x5 kernels for the first ~2/3 of the stages, 3x3 for the remainder.
        n_small = n_downsampling // 3
        ks_list = [5] * (n_downsampling - n_small) + [3] * n_small
        stride_list = [2] * n_downsampling
        # channels[0] is the stem width; channels[i + 1] is the width after stage i.
        channels = [ngf * 2 ** i for i in range(n_downsampling + 1)]

        # The 7x7 stem uses stride 1 and padding 3, so it preserves the input size.
        side = int(img_size)
        encoder = [
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=7 // 2, bias=True,
                      padding_mode='replicate'),
            norm_layer([ngf, side, side]),
            nn.ReLU(True),
        ]
        for chn_prev, chn_next, ks, stride in zip(channels, channels[1:], ks_list, stride_list):
            # With an odd kernel and padding ks // 2, Conv2d maps a side of
            # length s to ceil(s / stride); track it so the norm shape matches.
            side = (side + stride - 1) // stride
            encoder += [
                nn.Conv2d(chn_prev, chn_next, kernel_size=ks, stride=stride,
                          padding=ks // 2, padding_mode='replicate'),
                norm_layer([chn_next, side, side]),
                nn.ReLU(True),
            ]
        self.encode = nn.Sequential(*encoder)
        self.flatten = nn.Flatten()

    def forward(self, input):
        """Encode a batch of images.

        Args:
            input: Image tensor, expected shape
                ``(N, input_nc, img_size, img_size)``.

        Returns:
            A dict with the single key ``'img_feat'``. Its value holds the
            encoder output flattened from dim 1 onward (``nn.Flatten``
            defaults), i.e. shape ``(N, C * H' * W')``.
        """
        features = self.encode(input)
        return {'img_feat': self.flatten(features)}
|