gavinyuan
update: app.py import FSGenerator
a104d3f
raw
history blame
2.18 kB
import argparse
import pytorch_lightning as pl
import numpy as np
import torch
from third_party.arcface.mouth_net_pl import MouthNetPL
from third_party.arcface.mouth_net import MouthNet
class MouthTest(object):
    """Load a MouthNet "fixer" model and allocate zeroed embedding buffers.

    On construction the model is built, moved to CUDA, its backbone is
    loaded from ``fixer_path`` and it is switched to ``eval()``. Three
    ``(dataset_len, 128)`` float32 arrays are then allocated.

    Args:
        fixer_path: File handed to ``MouthNet.load_backbone``. ``None``
            (the default) selects ``DEFAULT_FIXER_PATH``.
        dataset_len: Number of rows in each embedding buffer.
    """

    # Weights file used when the caller does not supply ``fixer_path``.
    DEFAULT_FIXER_PATH = "/gavin/code/FaceSwapping/modules/third_party/arcface/weights/fixer_net_casia_28_56_84_112.pth"

    def __init__(self, fixer_path=None, dataset_len=400):
        if fixer_path is None:
            fixer_path = self.DEFAULT_FIXER_PATH

        # One value drives both the model's output size and the buffer width,
        # so the two cannot drift apart.
        feature_dim = 128

        self.dataset_len = dataset_len
        # Passed straight to MouthNet as ``crop_param``.
        # NOTE(review): presumably a crop rectangle on a 112x112 face; the
        # coordinate order is not visible from this file - confirm.
        self.fixer_crop_param = (28, 56, 84, 112)

        # NOTE(review): ``.cuda()`` requires a CUDA device; there is no CPU
        # fallback here.
        self.fixer_casia_model = MouthNet(
            bisenet=None,
            feature_dim=feature_dim,
            crop_param=self.fixer_crop_param
        ).cuda()
        self.fixer_casia_model.load_backbone(fixer_path)
        self.fixer_casia_model.eval()

        # Zero-initialised embedding buffers, one row per dataset item.
        # NOTE(review): the t / s / r suffixes presumably stand for
        # target / source / result - confirm against the code that fills them.
        buffer_shape = (self.dataset_len, feature_dim)
        self.fixer_t = np.zeros(buffer_shape, dtype=np.float32)
        self.fixer_s = np.zeros_like(self.fixer_t, dtype=np.float32)  # each embedding repeats 10 times in ffplus
        self.fixer_r = np.zeros_like(self.fixer_t, dtype=np.float32)

        print('Fixer model loaded.')
if __name__ == '__main__':
    # Checkpoint file plus the constructor overrides that differ per network.
    # Only the entry chosen with ``--net`` is loaded, so a run never pays for
    # (or fails on) a checkpoint it is not going to test.
    checkpoints = {
        "fixer_net": dict(
            path="/apdcephfs/share_1290939/gavinyuan/out/fixernet_casia/epoch=22-step=10999-v1.ckpt",
            num_classes=10572,
            header_type="AMCosFace",
        ),
        "lower_net_1": dict(
            path="/apdcephfs/share_1290939/gavinyuan/out/mouth_net_1/epoch=24-step=242999.ckpt",
            num_classes=93431,
            header_type="AMArcFace",
        ),
    }

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--net",
        choices=sorted(checkpoints),
        default="lower_net_1",
        help="which checkpoint to run trainer.test() on",
    )
    parser.add_argument(
        "--rec_folder",
        default="/gavin/datasets/msml/ms1m-retinaface",
        help="value forwarded to MouthNetPL as rec_folder",
    )
    args = parser.parse_args()
    # Not read anywhere in this script; kept so the namespace carries the
    # same attributes as before.
    args.val_targets = []

    overrides = dict(checkpoints[args.net])
    ckpt_path = overrides.pop("path")

    # Arguments shared by every network are written once; the per-network
    # ``num_classes`` / ``header_type`` come from ``overrides``.
    test_net = MouthNetPL.load_from_checkpoint(
        ckpt_path,
        map_location='cpu', strict=False,
        batch_size=128,
        dim_feature=128,
        rec_folder=args.rec_folder,
        crop=(28, 56, 84, 112),
        **overrides,
    )

    # NOTE(review): ``gpus`` and ``distributed_backend`` are legacy Trainer
    # arguments in newer pytorch_lightning releases - confirm the pinned
    # version before upgrading.
    trainer = pl.Trainer(
        logger=False,
        gpus=1,
        distributed_backend='dp',
        benchmark=True,
    )
    trainer.test(test_net)

    # print('Fixer model loading...')
    # m_test = MouthTest()