# ThaiVecFont / models/image_encoder.py
# Captured from the Hugging Face file view: last commit 667ab99 by microhum ("add dockerfile"), 1.64 kB.
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ImageEncoder(nn.Module):
    """Convolutional encoder that maps a square image to a flat feature vector.

    The network is a 7x7 stem convolution followed by ``floor(log2(img_size))``
    stride-2 downsampling convolutions. Every convolution is followed by
    ``norm_layer`` (constructed with a ``[channels, height, width]`` shape
    list, as ``nn.LayerNorm`` expects) and an in-place ReLU. Downsampling
    layer ``i`` outputs ``ngf * 2 ** (i + 1)`` channels.

    Args:
        img_size: Side length of the square input image. Powers of two reduce
            to a 1x1 map at the end; other sizes are also supported, with each
            stride-2 layer producing ``ceil(size / 2)``.
        input_nc: Number of input image channels.
        ngf: Number of channels produced by the stem convolution.
        norm_layer: Normalization factory called with a
            ``[channels, height, width]`` shape list.

    Raises:
        ValueError: If ``img_size`` is smaller than 1.
    """

    def __init__(self, img_size, input_nc, ngf=16, norm_layer=nn.LayerNorm):
        super().__init__()
        size = int(img_size)
        if size < 1:
            raise ValueError(f"img_size must be a positive integer, got {img_size!r}")
        # math.log2 is exact for powers of two; math.log(x, 2) divides two
        # rounded logarithms and can land an ULP off.
        n_downsampling = int(math.log2(size))

        # 5x5 kernels for the first two thirds of the downsampling stack,
        # 3x3 kernels for the last third.
        n_small = n_downsampling // 3
        ks_list = [5] * (n_downsampling - n_small) + [3] * n_small
        stride_list = [2] * n_downsampling
        chn_mult = [2 ** (i + 1) for i in range(n_downsampling)]

        # Stem: kernel 7, padding 3, stride 1 keeps the spatial size unchanged.
        encoder = [
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=7 // 2, bias=True, padding_mode='replicate'),
            norm_layer([ngf, size, size]),
            nn.ReLU(True),
        ]

        # Downsampling layers.
        chn_prev = ngf
        for ks, stride, mult in zip(ks_list, stride_list, chn_mult):
            chn_next = ngf * mult
            padding = ks // 2
            # Conv2d output-size formula (dilation 1), so the norm layer's shape
            # matches the real feature map even when img_size is not a power
            # of two. For powers of two this equals 2 ** (n_downsampling - 1 - i).
            size = (size + 2 * padding - ks) // stride + 1
            encoder += [
                nn.Conv2d(chn_prev, chn_next, kernel_size=ks, stride=stride, padding=padding, padding_mode='replicate'),
                norm_layer([chn_next, size, size]),
                nn.ReLU(True),
            ]
            chn_prev = chn_next

        self.encode = nn.Sequential(*encoder)
        self.flatten = nn.Flatten()

    def forward(self, input):
        """Encode an image batch into flat features.

        Args:
            input: Image batch; presumably shaped
                ``(batch, input_nc, img_size, img_size)``, since the norm layers
                are built for that spatial size — TODO confirm with callers.

        Returns:
            dict with a single key ``'img_feat'``: the encoder output flattened
            from dimension 1 onward (``nn.Flatten`` defaults).
        """
        ret = self.encode(input)
        img_feat = self.flatten(ret)
        return {'img_feat': img_feat}