# Source: face-swap/retinaface/models.py
# Last commit 69c590e ("with private models") by felixrosberg; file size 10.6 kB.
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.applications import MobileNetV2, ResNet50
from tensorflow.keras.layers import Input, Conv2D, ReLU, LeakyReLU
from retinaface.anchor import decode_tf, prior_box_tf
def _regularizer(weights_decay):
    """Build an L2 kernel regularizer.

    Args:
        weights_decay: L2 penalty factor passed to ``tf.keras.regularizers.l2``.

    Returns:
        The regularizer object, suitable for ``kernel_regularizer=``.
    """
    make_l2 = tf.keras.regularizers.l2
    regularizer = make_l2(weights_decay)
    return regularizer
def _kernel_init(scale=1.0, seed=None):
    """He-normal kernel initializer.

    Args:
        scale: multiplier applied on top of the He variance (2 / fan_in).
            The default 1.0 yields the standard He-normal initializer.
        seed: optional random seed forwarded to the initializer. Previously
            this argument (and ``scale``) was accepted but silently ignored.

    Returns:
        A ``tf.keras`` initializer instance.
    """
    if scale == 1.0:
        # Keep the exact initializer class for the default path so layer
        # configs serialize the same way as before.
        return tf.keras.initializers.he_normal(seed=seed)
    # He-normal is VarianceScaling(scale=2.0, mode='fan_in',
    # distribution='truncated_normal'); scale its variance by the factor.
    return tf.keras.initializers.VarianceScaling(
        scale=2.0 * scale, mode='fan_in',
        distribution='truncated_normal', seed=seed)
class BatchNormalization(tf.keras.layers.BatchNormalization):
    """BatchNormalization that honours ``trainable=False`` as a hard freeze.

    The effective training flag is ``training AND self.trainable``. A layer
    with ``trainable=False`` therefore always runs in inference mode, even
    when the surrounding model is called with ``training=True``. Defaults
    here are momentum=0.9 and epsilon=1e-5.
    ref: https://github.com/zzh8829/yolov3-tf2
    """

    def __init__(self, axis=-1, momentum=0.9, epsilon=1e-5, center=True,
                 scale=True, name=None, **kwargs):
        super().__init__(axis=axis, momentum=momentum, epsilon=epsilon,
                         center=center, scale=scale, name=name, **kwargs)

    def call(self, x, training=False):
        # An unspecified phase (None) is treated as inference.
        phase = tf.constant(False) if training is None else training
        # Only behave as "training" when requested AND the layer is trainable.
        effective_training = tf.logical_and(phase, self.trainable)
        return super().call(x, effective_training)
def Backbone(backbone_type='ResNet50', use_pretrain=True):
    """Return a callable that extracts three multi-scale feature maps.

    Args:
        backbone_type: 'ResNet50' or 'MobileNetV2'. Any other value raises
            NotImplementedError when the returned callable is invoked (not
            when Backbone itself is called).
        use_pretrain: if True, load ImageNet weights for the extractor;
            otherwise build it with random weights.

    Returns:
        A function ``backbone(x)`` that applies the matching
        ``preprocess_input`` to ``x``. It then returns the outputs of three
        fixed intermediate layers of the Keras application model, as a tuple.
        The shape comments below are the original author's and presumably
        assume a 640x640 input; TODO confirm.
    """
    weights = 'imagenet' if use_pretrain else None

    def backbone(x):
        if backbone_type == 'ResNet50':
            net_cls = ResNet50
            # [80, 80, 512], [40, 40, 1024], [20, 20, 2048]
            layer_ids = (80, 142, 174)
            preprocess = tf.keras.applications.resnet.preprocess_input
        elif backbone_type == 'MobileNetV2':
            net_cls = MobileNetV2
            # [80, 80, 32], [40, 40, 96], [20, 20, 160]
            layer_ids = (54, 116, 143)
            preprocess = tf.keras.applications.mobilenet_v2.preprocess_input
        else:
            raise NotImplementedError(
                'Backbone type {} is not recognized.'.format(backbone_type))

        extractor = net_cls(
            input_shape=x.shape[1:], include_top=False, weights=weights)
        picked_outputs = tuple(
            extractor.layers[idx].output for idx in layer_ids)
        # NOTE: the '_extrator' spelling is kept on purpose; the model name
        # may be referenced by saved weights.
        feature_model = Model(extractor.input, picked_outputs,
                              name=backbone_type + '_extrator')
        return feature_model(preprocess(x))

    return backbone
class ConvUnit(tf.keras.layers.Layer):
    """Conv2D -> BatchNormalization -> optional activation.

    Args:
        f: number of convolution filters.
        k: kernel size.
        s: stride (padding is 'same', no bias).
        wd: L2 weight-decay factor applied to the conv kernel.
        act: None (identity), 'relu', or 'lrelu' (LeakyReLU with slope 0.1).
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Raises:
        NotImplementedError: if ``act`` is not one of the values above.
    """

    def __init__(self, f, k, s, wd, act=None, **kwargs):
        super(ConvUnit, self).__init__(**kwargs)
        # Constructor arguments kept for get_config(), matching the other
        # layers in this module. Plain Python values are not tracked as
        # weights, so they do not change the layer's weight list.
        self.f = f
        self.k = k
        self.s = s
        self.wd = wd
        self.act = act
        self.conv = Conv2D(filters=f, kernel_size=k, strides=s, padding='same',
                           kernel_initializer=_kernel_init(),
                           kernel_regularizer=_regularizer(wd),
                           use_bias=False)
        self.bn = BatchNormalization()

        if act is None:
            self.act_fn = tf.identity
        elif act == 'relu':
            self.act_fn = ReLU()
        elif act == 'lrelu':
            self.act_fn = LeakyReLU(0.1)
        else:
            raise NotImplementedError(
                'Activation function type {} is not recognized.'.format(act))

    def call(self, x):
        return self.act_fn(self.bn(self.conv(x)))

    def get_config(self):
        """Serialize constructor arguments so the layer can be re-created."""
        config = {
            'f': self.f,
            'k': self.k,
            's': self.s,
            'wd': self.wd,
            'act': self.act,
        }
        base_config = super(ConvUnit, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class FPN(tf.keras.layers.Layer):
    """Feature Pyramid Network over three backbone feature maps.

    Each input level gets a 1x1 ConvUnit with ``out_ch`` filters. The
    coarsest level is nearest-upsampled and added into the middle level,
    which then goes through a 3x3 ConvUnit. The result is upsampled and
    added into the finest level, which also goes through a 3x3 ConvUnit.
    LeakyReLU is used when ``out_ch <= 64``, ReLU otherwise.

    Call input: a sequence ``(fine, middle, coarse)`` of NHWC tensors.
    Returns: ``(merged_fine, merged_middle, lateral_coarse)``.
    """

    def __init__(self, out_ch, wd, **kwargs):
        super(FPN, self).__init__(**kwargs)
        self.out_ch = out_ch
        self.wd = wd
        act = 'lrelu' if out_ch <= 64 else 'relu'

        def unit(kernel):
            return ConvUnit(f=out_ch, k=kernel, s=1, wd=wd, act=act)

        # Attribute names and assignment order are significant: they fix the
        # weight order and checkpoint paths used when loading saved weights.
        self.output1 = unit(1)
        self.output2 = unit(1)
        self.output3 = unit(1)
        self.merge1 = unit(3)
        self.merge2 = unit(3)

    @staticmethod
    def _upsample_to(src, ref):
        """Nearest-neighbour resize of ``src`` to ``ref``'s height/width."""
        ref_shape = tf.shape(ref)
        return tf.image.resize(src, [ref_shape[1], ref_shape[2]],
                               method='nearest')

    def call(self, x):
        lateral1 = self.output1(x[0])
        lateral2 = self.output2(x[1])
        lateral3 = self.output3(x[2])

        merged2 = self.merge2(lateral2 + self._upsample_to(lateral3, lateral2))
        merged1 = self.merge1(lateral1 + self._upsample_to(merged2, lateral1))
        return merged1, merged2, lateral3

    def get_config(self):
        config = super(FPN, self).get_config()
        config.update({'out_ch': self.out_ch, 'wd': self.wd})
        return config
class SSH(tf.keras.layers.Layer):
    """Single Stage Headless context module.

    Three branches are concatenated on axis 3 and passed through ReLU:
      * one 3x3 ConvUnit with ``out_ch // 2`` filters (no activation);
      * two stacked 3x3 ConvUnits, i.e. a 5x5 receptive field, with
        ``out_ch // 4`` filters;
      * three stacked 3x3 ConvUnits, i.e. a 7x7 receptive field, with
        ``out_ch // 4`` filters. This branch reuses the first conv of the
        5x5 branch.
    The intermediate convs use LeakyReLU when ``out_ch <= 64``, ReLU
    otherwise.

    Args:
        out_ch: number of output channels; must be divisible by 4.
        wd: L2 weight-decay factor forwarded to each ConvUnit.

    Raises:
        ValueError: if ``out_ch`` is not divisible by 4.
    """

    def __init__(self, out_ch, wd, **kwargs):
        super(SSH, self).__init__(**kwargs)
        # Explicit check instead of `assert`, which is stripped under -O.
        if out_ch % 4 != 0:
            raise ValueError(
                'SSH out_ch must be divisible by 4, got {}.'.format(out_ch))
        self.out_ch = out_ch
        self.wd = wd
        act = 'relu'
        if out_ch <= 64:
            act = 'lrelu'

        # Attribute names and assignment order are kept: they determine the
        # weight order and checkpoint paths of previously saved models.
        self.conv_3x3 = ConvUnit(f=out_ch // 2, k=3, s=1, wd=wd, act=None)

        self.conv_5x5_1 = ConvUnit(f=out_ch // 4, k=3, s=1, wd=wd, act=act)
        self.conv_5x5_2 = ConvUnit(f=out_ch // 4, k=3, s=1, wd=wd, act=None)

        self.conv_7x7_2 = ConvUnit(f=out_ch // 4, k=3, s=1, wd=wd, act=act)
        self.conv_7x7_3 = ConvUnit(f=out_ch // 4, k=3, s=1, wd=wd, act=None)

        self.relu = ReLU()

    def call(self, x):
        conv_3x3 = self.conv_3x3(x)

        conv_5x5_1 = self.conv_5x5_1(x)
        conv_5x5 = self.conv_5x5_2(conv_5x5_1)

        # The 7x7 branch continues from the first conv of the 5x5 branch.
        conv_7x7_2 = self.conv_7x7_2(conv_5x5_1)
        conv_7x7 = self.conv_7x7_3(conv_7x7_2)

        output = tf.concat([conv_3x3, conv_5x5, conv_7x7], axis=3)
        output = self.relu(output)

        return output

    def get_config(self):
        config = {
            'out_ch': self.out_ch,
            'wd': self.wd,
        }
        base_config = super(SSH, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
class BboxHead(tf.keras.layers.Layer):
    """1x1-conv head emitting 4 box-regression values per anchor.

    Args:
        num_anchor: anchors per feature-map cell.
        wd: stored and serialized in get_config, but NOT applied to the conv
            (the Conv2D below has no kernel_regularizer).

    Call returns a tensor of shape ``[batch, H * W * num_anchor, 4]``.
    """

    def __init__(self, num_anchor, wd, **kwargs):
        super(BboxHead, self).__init__(**kwargs)
        self.num_anchor = num_anchor
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 4, kernel_size=1, strides=1)

    def call(self, x):
        in_shape = tf.shape(x)
        cells = in_shape[1] * in_shape[2]
        # A 1x1 / stride-1 conv keeps H and W, so the input's spatial size
        # can be used for the flattening.
        raw = self.conv(x)
        return tf.reshape(raw, [-1, cells * self.num_anchor, 4])

    def get_config(self):
        config = super(BboxHead, self).get_config()
        config.update({'num_anchor': self.num_anchor, 'wd': self.wd})
        return config
class LandmarkHead(tf.keras.layers.Layer):
    """1x1-conv head emitting 10 landmark values per anchor.

    Args:
        num_anchor: anchors per feature-map cell.
        wd: stored and serialized in get_config, but NOT applied to the conv
            (the Conv2D below has no kernel_regularizer).
        name: layer name, 'LandmarkHead' by default.

    Call returns a tensor of shape ``[batch, H * W * num_anchor, 10]``.
    """

    def __init__(self, num_anchor, wd, name='LandmarkHead', **kwargs):
        super(LandmarkHead, self).__init__(name=name, **kwargs)
        self.num_anchor = num_anchor
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 10, kernel_size=1, strides=1)

    def call(self, x):
        in_shape = tf.shape(x)
        cells = in_shape[1] * in_shape[2]
        raw = self.conv(x)  # 1x1, stride 1: spatial size unchanged
        return tf.reshape(raw, [-1, cells * self.num_anchor, 10])

    def get_config(self):
        config = super(LandmarkHead, self).get_config()
        config.update({'num_anchor': self.num_anchor, 'wd': self.wd})
        return config
class ClassHead(tf.keras.layers.Layer):
    """1x1-conv head emitting 2 class logits per anchor.

    Args:
        num_anchor: anchors per feature-map cell.
        wd: stored and serialized in get_config, but NOT applied to the conv
            (the Conv2D below has no kernel_regularizer).
        name: layer name, 'ClassHead' by default.

    Call returns a tensor of shape ``[batch, H * W * num_anchor, 2]``.
    """

    def __init__(self, num_anchor, wd, name='ClassHead', **kwargs):
        super(ClassHead, self).__init__(name=name, **kwargs)
        self.num_anchor = num_anchor
        self.wd = wd
        self.conv = Conv2D(filters=num_anchor * 2, kernel_size=1, strides=1)

    def call(self, x):
        in_shape = tf.shape(x)
        cells = in_shape[1] * in_shape[2]
        raw = self.conv(x)  # 1x1, stride 1: spatial size unchanged
        return tf.reshape(raw, [-1, cells * self.num_anchor, 2])

    def get_config(self):
        config = super(ClassHead, self).get_config()
        config.update({'num_anchor': self.num_anchor, 'wd': self.wd})
        return config
def RetinaFaceModel(cfg, training=False, iou_th=0.4, score_th=0.02,
                    name='RetinaFaceModel', use_pretrain=True):
    """Build the RetinaFace detector as two functional Keras models.

    Args:
        cfg: config dict. Always read: 'weights_decay', 'out_channel',
            'min_sizes', 'backbone_type'. Read only when ``training``:
            'input_size'. Read only when not ``training``: 'steps', 'clip',
            'variances'.
        training: if True, the input is fixed at
            ``cfg['input_size']`` x ``cfg['input_size']`` x 3 and the first
            model outputs the raw ``(bbox, landmark, class)`` tensors.
            Otherwise the spatial size is dynamic and the first model
            outputs decoded, NMS-filtered detections for batch item 0 only.
        iou_th: IoU threshold for non-max suppression (inference only).
        score_th: score threshold for non-max suppression (inference only).
        name: name of the first model; the second is ``name + '_bb_only'``.
        use_pretrain: forwarded to ``Backbone``. When False, the backbone is
            built without ImageNet weights. This is useful when trained
            weights are loaded afterwards anyway. Defaults to True, which
            matches the previous behaviour.

    Returns:
        ``(model, bb_only_model)``. Both share all layers. ``bb_only_model``
        always outputs ``[bbox_regressions, landm_regressions,
        classifications]``, with classifications softmax-normalised.
        In inference mode, each row of ``model``'s output is one
        ``decode_tf`` result row kept by NMS. NMS reads columns ``[:4]``
        as boxes and the last column as the score.
    """
    input_size = cfg['input_size'] if training else None
    wd = cfg['weights_decay']
    out_ch = cfg['out_channel']
    # Assumes every pyramid level has the same number of min_sizes; TODO
    # confirm against the config files.
    num_anchor = len(cfg['min_sizes'][0])
    backbone_type = cfg['backbone_type']

    # define model
    # NOTE: layer-creation order below determines auto-generated layer names
    # (e.g. ssh, ssh_1, bbox_head_1) and weight order; keep it stable so
    # saved weights still load.
    x = inputs = Input([input_size, input_size, 3], name='input_image')
    x = Backbone(backbone_type=backbone_type, use_pretrain=use_pretrain)(x)

    fpn = FPN(out_ch=out_ch, wd=wd)(x)

    features = [SSH(out_ch=out_ch, wd=wd)(f) for f in fpn]

    # Per-level head outputs are concatenated along the anchor axis.
    bbox_regressions = tf.concat(
        [BboxHead(num_anchor, wd=wd)(f) for f in features], axis=1)
    landm_regressions = tf.concat(
        [LandmarkHead(num_anchor, wd=wd, name=f'LandmarkHead_{i}')(f)
         for i, f in enumerate(features)], axis=1)
    classifications = tf.concat(
        [ClassHead(num_anchor, wd=wd, name=f'ClassHead_{i}')(f)
         for i, f in enumerate(features)], axis=1)

    classifications = tf.keras.layers.Softmax(axis=-1)(classifications)

    if training:
        out = (bbox_regressions, landm_regressions, classifications)
    else:
        # only for batch size 1: every tensor below is indexed at [0].
        preds = tf.concat(  # [bboxes, landms, landms_valid, conf]
            [bbox_regressions[0],
             landm_regressions[0],
             # landms_valid column: constant 1 for every anchor.
             tf.ones_like(classifications[0, :, 0][..., tf.newaxis]),
             # conf column: softmax probability of class index 1.
             classifications[0, :, 1][..., tf.newaxis]], 1)

        # Priors are generated for the runtime input height/width.
        priors = prior_box_tf((tf.shape(inputs)[1], tf.shape(inputs)[2]),
                              cfg['min_sizes'], cfg['steps'], cfg['clip'])
        # decode_tf presumably converts prior-relative offsets into image
        # coordinates; TODO confirm in retinaface.anchor.
        decode_preds = decode_tf(preds, priors, cfg['variances'])

        selected_indices = tf.image.non_max_suppression(
            boxes=decode_preds[:, :4],
            scores=decode_preds[:, -1],
            max_output_size=tf.shape(decode_preds)[0],
            iou_threshold=iou_th,
            score_threshold=score_th)

        out = tf.gather(decode_preds, selected_indices)

    return (Model(inputs, out, name=name),
            Model(inputs,
                  [bbox_regressions, landm_regressions, classifications],
                  name=name + '_bb_only'))