svjack's picture
Upload SPIGA with huggingface_hub
9390e2c
raw
history blame
1.58 kB
import torch.nn as nn
from spiga.models.cnn.layers import Conv, Deconv, Residual
class Hourglass(nn.Module):
    """Recursive hourglass block: a skip branch plus a lower (recursive) branch.

    The lower branch runs ``pool1 -> low1 -> low2 -> low3 -> up2``. Its result
    is added element-wise to the skip branch ``up1(x)``.

    Args:
        n: Recursion depth. For ``n > 1`` the middle stage ``low2`` is itself a
            ``Hourglass(n - 1, ...)``; otherwise it is a single ``Residual``.
        f: Channel count of the input and of the output.
        bn: Not used by the layers built here (they pass ``bn=True``
            explicitly); it is only forwarded to the nested ``Hourglass``.
        increase: Extra channels inside the lower branch, which works with
            ``f + increase`` channels. It is NOT forwarded to the nested
            level, so nested hourglasses use the default ``increase=0``.
    """

    def __init__(self, n, f, bn=None, increase=0):
        super().__init__()
        inner_ch = f + increase
        # Submodules are created in the original order. Attribute names become
        # state_dict keys, so they must stay unchanged.

        # Skip branch.
        self.up1 = Residual(f, f)

        # Lower branch entry. Conv(f, f, 2, 2, ...) is presumably a 2x2,
        # stride-2 downsampling conv -- TODO confirm against layers.Conv.
        self.pool1 = Conv(f, f, 2, 2, bn=True, relu=True)
        self.low1 = Residual(f, inner_ch)

        self.n = n
        if n <= 1:
            # Deepest level: close the recursion with a bottleneck residual.
            self.low2 = Residual(inner_ch, inner_ch)
        else:
            self.low2 = Hourglass(n - 1, inner_ch, bn=bn)

        # Lower branch exit: back to ``f`` channels, then Deconv(f, f, 2, 2, ...),
        # presumably upsampling to the skip-branch size so the sum is valid.
        self.low3 = Residual(inner_ch, f)
        self.up2 = Deconv(f, f, 2, 2, bn=True, relu=True)

    def forward(self, x):
        """Return ``up1(x) + up2(low3(low2(low1(pool1(x)))))``."""
        skip = self.up1(x)
        lower = x
        for stage in (self.pool1, self.low1, self.low2, self.low3, self.up2):
            lower = stage(lower)
        return skip + lower
class HourglassCore(Hourglass):
    """``Hourglass`` variant whose ``forward`` also returns inner feature maps.

    Every nested level is replaced by another ``HourglassCore``. This lets a
    single ``core`` list collect features from all recursion depths.

    Args:
        n: Recursion depth (see ``Hourglass``).
        f: Channel count of the input and of the output.
        bn: Forwarded to the base class and to the nested ``HourglassCore``.
        increase: Extra channels inside the lower branch (see ``Hourglass``).
    """

    def __init__(self, n, f, bn=None, increase=0):
        super().__init__(n, f, bn=bn, increase=increase)
        nf = f + increase
        if self.n > 1:
            # The base constructor already built a plain ``Hourglass`` here.
            # Replace it so nested levels also report their core features.
            # NOTE(review): building the discarded sub-hourglass is wasted work,
            # but skipping it would change the order of weight initialisation.
            self.low2 = HourglassCore(n - 1, nf, bn=bn)

    def forward(self, x, core=None):
        """Run the hourglass and collect inner features.

        Args:
            x: Input to this hourglass level.
            core: Optional list that features are appended to in place.
                When omitted, a fresh list is created for this call. A mutable
                ``[]`` default would be shared across calls and accumulate
                tensors indefinitely.

        Returns:
            Tuple ``(out, core)``. ``out`` is ``up1(x) + up2(...)``, as in
            ``Hourglass``. ``core`` holds, deepest level first: the bottleneck
            output ``low2`` of the innermost level, then the ``low3`` output of
            each enclosing level.
        """
        if core is None:
            core = []

        up1 = self.up1(x)
        low1 = self.low1(self.pool1(x))

        if self.n > 1:
            # The nested level appends its (deeper) features first.
            low2, core = self.low2(low1, core=core)
            low3 = self.low3(low2)
            core.append(low3)
        else:
            # Innermost level: expose the bottleneck output.
            low2 = self.low2(low1)
            core.append(low2)
            low3 = self.low3(low2)

        up2 = self.up2(low3)
        return up1 + up2, core