diff --git a/scripts/download_models.sh b/scripts/download_models.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dab0a435560847129729334547e355909062ace2
--- /dev/null
+++ b/scripts/download_models.sh
@@ -0,0 +1,22 @@
+mkdir ./checkpoints  
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2exp_00300-model.pth -O ./checkpoints/auido2exp_00300-model.pth
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/auido2pose_00140-model.pth -O ./checkpoints/auido2pose_00140-model.pth
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/epoch_20.pth -O ./checkpoints/epoch_20.pth
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/facevid2vid_00189-model.pth.tar -O ./checkpoints/facevid2vid_00189-model.pth.tar
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/shape_predictor_68_face_landmarks.dat -O ./checkpoints/shape_predictor_68_face_landmarks.dat
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/wav2lip.pth -O ./checkpoints/wav2lip.pth
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00229-model.pth.tar -O ./checkpoints/mapping_00229-model.pth.tar
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/mapping_00109-model.pth.tar -O ./checkpoints/mapping_00109-model.pth.tar
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/BFM_Fitting.zip -O ./checkpoints/BFM_Fitting.zip
+wget -nc https://github.com/Winfredy/SadTalker/releases/download/v0.0.2/hub.zip -O ./checkpoints/hub.zip
+
+
+unzip -n ./checkpoints/hub.zip -d ./checkpoints/
+unzip -n ./checkpoints/BFM_Fitting.zip -d ./checkpoints/
+
+mkdir -p ./gfpgan/weights
+wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth -O ./gfpgan/weights/alignment_WFLW_4HG.pth 
+wget -nc https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth -O ./gfpgan/weights/detection_Resnet50_Final.pth 
+wget -nc https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth -O ./gfpgan/weights/GFPGANv1.4.pth 
+wget -nc https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth -O ./gfpgan/weights/parsing_parsenet.pth 
+
diff --git a/scripts/extension.py b/scripts/extension.py
new file mode 100644
index 0000000000000000000000000000000000000000..8cd9ca892069326a9b462c116445fe4d79fa8dfb
--- /dev/null
+++ b/scripts/extension.py
@@ -0,0 +1,209 @@
+import os, sys
+from pathlib import Path
+import tempfile
+import gradio as gr
+from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call
+from modules.shared import opts, OptionInfo
+from modules import shared, paths, script_callbacks
+import launch
+import glob
+from huggingface_hub import snapshot_download
+
+
+def check_all_files(current_dir):
+    kv = {
+        "auido2exp_00300-model.pth": "audio2exp",
+        "auido2pose_00140-model.pth": "audio2pose",
+        "epoch_20.pth": "face_recon",
+        "facevid2vid_00189-model.pth.tar": "face-render",
+        "mapping_00109-model.pth.tar" : "mapping-109" ,
+        "mapping_00229-model.pth.tar" : "mapping-229" ,
+        "wav2lip.pth": "wav2lip",
+        "shape_predictor_68_face_landmarks.dat": "dlib",
+    }
+
+    if not os.path.isdir(current_dir):
+        return False
+    
+    dirs = os.listdir(current_dir)
+
+    for f in dirs:
+        if f in kv.keys():
+            del kv[f]
+
+    return len(kv.keys()) == 0
+
+    
+
+def download_model(local_dir='./checkpoints'):
+    REPO_ID = 'vinthony/SadTalker'
+    snapshot_download(repo_id=REPO_ID, local_dir=local_dir, local_dir_use_symlinks=False)
+
+def get_source_image(image):   
+        return image
+
+def get_img_from_txt2img(x):
+    talker_path = Path(paths.script_path) / "outputs"
+    imgs_from_txt_dir = str(talker_path / "txt2img-images/")
+    imgs = glob.glob(imgs_from_txt_dir+'/*/*.png')
+    imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_txt_dir, x)))
+    img_from_txt_path = os.path.join(imgs_from_txt_dir, imgs[-1])
+    return img_from_txt_path, img_from_txt_path
+
+def get_img_from_img2img(x):
+    talker_path = Path(paths.script_path) / "outputs"
+    imgs_from_img_dir = str(talker_path / "img2img-images/")
+    imgs = glob.glob(imgs_from_img_dir+'/*/*.png')
+    imgs.sort(key=lambda x:os.path.getmtime(os.path.join(imgs_from_img_dir, x)))
+    img_from_img_path = os.path.join(imgs_from_img_dir, imgs[-1])
+    return img_from_img_path, img_from_img_path
+ 
+def get_default_checkpoint_path():
+    # check the path of models/checkpoints and extensions/
+    checkpoint_path = Path(paths.script_path) / "models"/ "SadTalker" 
+    extension_checkpoint_path = Path(paths.script_path) / "extensions"/ "SadTalker" / "checkpoints"
+
+    if check_all_files(checkpoint_path):
+        # print('founding sadtalker checkpoint in ' + str(checkpoint_path))
+        return checkpoint_path
+
+    if check_all_files(extension_checkpoint_path):
+        # print('founding sadtalker checkpoint in ' + str(extension_checkpoint_path))
+        return extension_checkpoint_path
+
+    return None
+
+
+
+def install():
+
+    kv = {
+        "face_alignment": "face-alignment==1.3.5",
+        "imageio": "imageio==2.19.3",
+        "imageio_ffmpeg": "imageio-ffmpeg==0.4.7",
+        "librosa":"librosa==0.8.0",
+        "pydub":"pydub==0.25.1",
+        "scipy":"scipy==1.8.1",
+        "tqdm": "tqdm",
+        "yacs":"yacs==0.1.8",
+        "yaml": "pyyaml", 
+        "av":"av",
+        "gfpgan": "gfpgan",
+    }
+
+    if 'darwin' in sys.platform:
+        kv['dlib'] = "dlib"
+    else:
+        kv['dlib'] = 'dlib-bin'
+
+    for k,v in kv.items():
+        if not launch.is_installed(k):
+            print(k, launch.is_installed(k))
+            launch.run_pip("install "+ v, "requirements for SadTalker")
+
+
+    if os.getenv('SADTALKER_CHECKPOINTS'):
+        print('load Sadtalker Checkpoints from '+ os.getenv('SADTALKER_CHECKPOINTS'))
+
+    elif get_default_checkpoint_path() is not None:
+        os.environ['SADTALKER_CHECKPOINTS'] = str(get_default_checkpoint_path())
+    else:
+
+        print(
+            """"
+            SadTalker will not support download all the files from hugging face, which will take a long time.
+             
+            please manually set the SADTALKER_CHECKPOINTS in `webui_user.bat`(windows) or `webui_user.sh`(linux)
+            """
+            )
+        
+        # python = sys.executable
+
+        # launch.run(f'"{python}" -m pip uninstall -y huggingface_hub', live=True)
+        # launch.run(f'"{python}" -m pip install --upgrade git+https://github.com/huggingface/huggingface_hub@main', live=True)
+        # ### run the scripts to downlod models to correct localtion.
+        # # print('download models for SadTalker')
+        # # launch.run("cd " + paths.script_path+"/extensions/SadTalker && bash ./scripts/download_models.sh", live=True)
+        # # print('SadTalker is successfully installed!')
+        # download_model(paths.script_path+'/extensions/SadTalker/checkpoints')
+    
+ 
+def on_ui_tabs():
+    install()
+
+    sys.path.extend([paths.script_path+'/extensions/SadTalker']) 
+    
+    repo_dir = paths.script_path+'/extensions/SadTalker/'
+
+    result_dir = opts.sadtalker_result_dir
+    os.makedirs(result_dir, exist_ok=True)
+
+    from src.gradio_demo import SadTalker  
+
+    if  os.getenv('SADTALKER_CHECKPOINTS'):
+        checkpoint_path = os.getenv('SADTALKER_CHECKPOINTS')
+    else:
+        checkpoint_path = repo_dir+'checkpoints/'
+
+    sad_talker = SadTalker(checkpoint_path=checkpoint_path, config_path=repo_dir+'src/config', lazy_load=True)
+    
+    with gr.Blocks(analytics_enabled=False) as audio_to_video:
+        with gr.Row().style(equal_height=False):
+            with gr.Column(variant='panel'):
+                with gr.Tabs(elem_id="sadtalker_source_image"):
+                    with gr.TabItem('Upload image'):
+                        with gr.Row():
+                            input_image = gr.Image(label="Source image", source="upload", type="filepath").style(height=256,width=256)
+                        
+                        with gr.Row():
+                            submit_image2 = gr.Button('load From txt2img', variant='primary')
+                            submit_image2.click(fn=get_img_from_txt2img, inputs=input_image, outputs=[input_image, input_image])
+                            
+                            submit_image3 = gr.Button('load from img2img', variant='primary')
+                            submit_image3.click(fn=get_img_from_img2img, inputs=input_image, outputs=[input_image, input_image])
+
+                with gr.Tabs(elem_id="sadtalker_driven_audio"):
+                    with gr.TabItem('Upload'):
+                        with gr.Column(variant='panel'):
+
+                            with gr.Row():
+                                driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
+                        
+
+            with gr.Column(variant='panel'): 
+                with gr.Tabs(elem_id="sadtalker_checkbox"):
+                    with gr.TabItem('Settings'):
+                        with gr.Column(variant='panel'):
+                            gr.Markdown("Please visit [**[here]**](https://github.com/Winfredy/SadTalker/blob/main/docs/best_practice.md) if you don't know how to choose these configurations.")
+                            preprocess_type = gr.Radio(['crop','resize','full'], value='crop', label='preprocess', info="How to handle input image?")
+                            is_still_mode = gr.Checkbox(label="Remove head motion (works better with preprocess `full`)")
+                            enhancer = gr.Checkbox(label="Face enhancement")
+                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
+                            path_to_save = gr.Text(Path(paths.script_path) / "outputs/SadTalker/", visible=False)
+
+                with gr.Tabs(elem_id="sadtalker_genearted"):
+                        gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)
+
+
+        ### gradio gpu call will always return the html, 
+        submit.click(
+                    fn=wrap_queued_call(sad_talker.test), 
+                    inputs=[input_image,
+                            driven_audio,
+                            preprocess_type,
+                            is_still_mode,
+                            enhancer,
+                            path_to_save
+                            ], 
+                    outputs=[gen_video, ]
+                    )
+
+    return [(audio_to_video, "SadTalker", "extension")]
+
+def on_ui_settings():
+    talker_path = Path(paths.script_path) / "outputs"
+    section = ('extension', "SadTalker") 
+    opts.add_option("sadtalker_result_dir", OptionInfo(str(talker_path / "SadTalker/"), "Path to save results of sadtalker", section=section)) 
+
+script_callbacks.on_ui_settings(on_ui_settings)
+script_callbacks.on_ui_tabs(on_ui_tabs)
diff --git a/scripts/test.sh b/scripts/test.sh
new file mode 100644
index 0000000000000000000000000000000000000000..69147611ead3c6c2e45bdd6daea3becb216f501c
--- /dev/null
+++ b/scripts/test.sh
@@ -0,0 +1 @@
+### some test command before commit.
\ No newline at end of file
diff --git a/src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc b/src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2f9f6dade6257863679600e6ede3827af7f3c18f
Binary files /dev/null and b/src/audio2exp_models/__pycache__/audio2exp.cpython-38.pyc differ
diff --git a/src/audio2exp_models/__pycache__/networks.cpython-38.pyc b/src/audio2exp_models/__pycache__/networks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c1a90d0d0095e8f715ad1d38d9036204f55a2b97
Binary files /dev/null and b/src/audio2exp_models/__pycache__/networks.cpython-38.pyc differ
diff --git a/src/audio2exp_models/audio2exp.py b/src/audio2exp_models/audio2exp.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1062ab6684df01e0b3c48b6b577cc8df0503c91
--- /dev/null
+++ b/src/audio2exp_models/audio2exp.py
@@ -0,0 +1,41 @@
+from tqdm import tqdm
+import torch
+from torch import nn
+
+
+class Audio2Exp(nn.Module):
+    def __init__(self, netG, cfg, device, prepare_training_loss=False):
+        super(Audio2Exp, self).__init__()
+        self.cfg = cfg
+        self.device = device
+        self.netG = netG.to(device)
+
+    def test(self, batch):
+
+        mel_input = batch['indiv_mels']                         # bs T 1 80 16
+        bs = mel_input.shape[0]
+        T = mel_input.shape[1]
+
+        exp_coeff_pred = []
+
+        for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
+            
+            current_mel_input = mel_input[:,i:i+10]
+
+            #ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1))           #bs T 64
+            ref = batch['ref'][:, :, :64][:, i:i+10]
+            ratio = batch['ratio_gt'][:, i:i+10]                               #bs T
+
+            audiox = current_mel_input.view(-1, 1, 80, 16)                  # bs*T 1 80 16
+
+            curr_exp_coeff_pred  = self.netG(audiox, ref, ratio)         # bs T 64 
+
+            exp_coeff_pred += [curr_exp_coeff_pred]
+
+        # BS x T x 64
+        results_dict = {
+            'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
+            }
+        return results_dict
+
+
diff --git a/src/audio2exp_models/networks.py b/src/audio2exp_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd77a2f48d7c00ce85fe2eefe3a3e820730fbb74
--- /dev/null
+++ b/src/audio2exp_models/networks.py
@@ -0,0 +1,74 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+                            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+                            nn.BatchNorm2d(cout)
+                            )
+        self.act = nn.ReLU()
+        self.residual = residual
+        self.use_act = use_act
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        if self.residual:
+            out += x
+        
+        if self.use_act:
+            return self.act(out)
+        else:
+            return out
+
+class SimpleWrapperV2(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.audio_encoder = nn.Sequential(
+            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
+            )
+
+        #### load the pre-trained audio_encoder 
+        #self.audio_encoder = self.audio_encoder.to(device)  
+        '''
+        wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
+        state_dict = self.audio_encoder.state_dict()
+
+        for k,v in wav2lip_state_dict.items():
+            if 'audio_encoder' in k:
+                print('init:', k)
+                state_dict[k.replace('module.audio_encoder.', '')] = v
+        self.audio_encoder.load_state_dict(state_dict)
+        '''
+
+        self.mapping1 = nn.Linear(512+64+1, 64)
+        #self.mapping2 = nn.Linear(30, 64)
+        #nn.init.constant_(self.mapping1.weight, 0.)
+        nn.init.constant_(self.mapping1.bias, 0.)
+
+    def forward(self, x, ref, ratio):
+        x = self.audio_encoder(x).view(x.size(0), -1)
+        ref_reshape = ref.reshape(x.size(0), -1)
+        ratio = ratio.reshape(x.size(0), -1)
+        
+        y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1)) 
+        out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # resudial
+        return out
diff --git a/src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc b/src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c61613a5ff71372ab102841c600fe4b45b00f201
Binary files /dev/null and b/src/audio2pose_models/__pycache__/audio2pose.cpython-38.pyc differ
diff --git a/src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc b/src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8c597f141934125b980a75caf2313dee02de7224
Binary files /dev/null and b/src/audio2pose_models/__pycache__/audio_encoder.cpython-38.pyc differ
diff --git a/src/audio2pose_models/__pycache__/cvae.cpython-38.pyc b/src/audio2pose_models/__pycache__/cvae.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b54857612bf5ea8e2a68775f65b1c5e6db4c35e5
Binary files /dev/null and b/src/audio2pose_models/__pycache__/cvae.cpython-38.pyc differ
diff --git a/src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc b/src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..481f3eb6e72cb5be8f8c91fff6ef10465d94f860
Binary files /dev/null and b/src/audio2pose_models/__pycache__/discriminator.cpython-38.pyc differ
diff --git a/src/audio2pose_models/__pycache__/networks.cpython-38.pyc b/src/audio2pose_models/__pycache__/networks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ce23c0a327f62d4704d3030c0d3d6cdc6e70363d
Binary files /dev/null and b/src/audio2pose_models/__pycache__/networks.cpython-38.pyc differ
diff --git a/src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc b/src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..78afad22ff783f4db56488fc034253b62fa88c3b
Binary files /dev/null and b/src/audio2pose_models/__pycache__/res_unet.cpython-38.pyc differ
diff --git a/src/audio2pose_models/audio2pose.py b/src/audio2pose_models/audio2pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc956b83ad4aa54d5366d73cbd0c7a563ef5d05c
--- /dev/null
+++ b/src/audio2pose_models/audio2pose.py
@@ -0,0 +1,94 @@
+import torch
+from torch import nn
+from src.audio2pose_models.cvae import CVAE
+from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
+from src.audio2pose_models.audio_encoder import AudioEncoder
+
+class Audio2Pose(nn.Module):
+    def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
+        super().__init__()
+        self.cfg = cfg
+        self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
+        self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
+        self.device = device
+
+        self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
+        self.audio_encoder.eval()
+        for param in self.audio_encoder.parameters():
+            param.requires_grad = False
+
+        self.netG = CVAE(cfg)
+        self.netD_motion = PoseSequenceDiscriminator(cfg)
+        
+        
+    def forward(self, x):
+
+        batch = {}
+        coeff_gt = x['gt'].cuda().squeeze(0)           #bs frame_len+1 73
+        batch['pose_motion_gt'] = coeff_gt[:, 1:, -9:-3] - coeff_gt[:, :1, -9:-3] #bs frame_len 6
+        batch['ref'] = coeff_gt[:, 0, -9:-3]  #bs  6
+        batch['class'] = x['class'].squeeze(0).cuda() # bs
+        indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
+
+        # forward
+        audio_emb_list = []
+        audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
+        batch['audio_emb'] = audio_emb
+        batch = self.netG(batch)
+
+        pose_motion_pred = batch['pose_motion_pred']           # bs frame_len 6
+        pose_gt = coeff_gt[:, 1:, -9:-3].clone()               # bs frame_len 6
+        pose_pred = coeff_gt[:, :1, -9:-3] + pose_motion_pred  # bs frame_len 6
+
+        batch['pose_pred'] = pose_pred
+        batch['pose_gt'] = pose_gt
+
+        return batch
+
+    def test(self, x):
+
+        batch = {}
+        ref = x['ref']                            #bs 1 70
+        batch['ref'] = x['ref'][:,0,-6:]  
+        batch['class'] = x['class']  
+        bs = ref.shape[0]
+        
+        indiv_mels= x['indiv_mels']               # bs T 1 80 16
+        indiv_mels_use = indiv_mels[:, 1:]        # we regard the ref as the first frame
+        num_frames = x['num_frames']
+        num_frames = int(num_frames) - 1
+
+        #  
+        div = num_frames//self.seq_len
+        re = num_frames%self.seq_len
+        audio_emb_list = []
+        pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype, 
+                                                device=batch['ref'].device)]
+
+        for i in range(div):
+            z = torch.randn(bs, self.latent_dim).to(ref.device)
+            batch['z'] = z
+            audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
+            batch['audio_emb'] = audio_emb
+            batch = self.netG.test(batch)
+            pose_motion_pred_list.append(batch['pose_motion_pred'])  #list of bs seq_len 6
+        
+        if re != 0:
+            z = torch.randn(bs, self.latent_dim).to(ref.device)
+            batch['z'] = z
+            audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len  512
+            if audio_emb.shape[1] != self.seq_len:
+                pad_dim = self.seq_len-audio_emb.shape[1]
+                pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1) 
+                audio_emb = torch.cat([pad_audio_emb, audio_emb], 1) 
+            batch['audio_emb'] = audio_emb
+            batch = self.netG.test(batch)
+            pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])   
+        
+        pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
+        batch['pose_motion_pred'] = pose_motion_pred
+
+        pose_pred = ref[:, :1, -6:] + pose_motion_pred  # bs T 6
+
+        batch['pose_pred'] = pose_pred
+        return batch
diff --git a/src/audio2pose_models/audio_encoder.py b/src/audio2pose_models/audio_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..424f68b61bfb094bb0eedb4b84645ff4dbbb105b
--- /dev/null
+++ b/src/audio2pose_models/audio_encoder.py
@@ -0,0 +1,64 @@
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+class Conv2d(nn.Module):
+    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.conv_block = nn.Sequential(
+                            nn.Conv2d(cin, cout, kernel_size, stride, padding),
+                            nn.BatchNorm2d(cout)
+                            )
+        self.act = nn.ReLU()
+        self.residual = residual
+
+    def forward(self, x):
+        out = self.conv_block(x)
+        if self.residual:
+            out += x
+        return self.act(out)
+
+class AudioEncoder(nn.Module):
+    def __init__(self, wav2lip_checkpoint, device):
+        super(AudioEncoder, self).__init__()
+
+        self.audio_encoder = nn.Sequential(
+            Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+            Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
+            Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
+
+            Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
+            Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
+
+        #### load the pre-trained audio_encoder
+        wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
+        state_dict = self.audio_encoder.state_dict()
+
+        for k,v in wav2lip_state_dict.items():
+            if 'audio_encoder' in k:
+                state_dict[k.replace('module.audio_encoder.', '')] = v
+        self.audio_encoder.load_state_dict(state_dict)
+
+
+    def forward(self, audio_sequences):
+        # audio_sequences = (B, T, 1, 80, 16)
+        B = audio_sequences.size(0)
+
+        audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
+
+        audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
+        dim = audio_embedding.shape[1]
+        audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
+
+        return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512 
diff --git a/src/audio2pose_models/cvae.py b/src/audio2pose_models/cvae.py
new file mode 100644
index 0000000000000000000000000000000000000000..407b78894cde564dd3f2819772a84e8bb1de251d
--- /dev/null
+++ b/src/audio2pose_models/cvae.py
@@ -0,0 +1,149 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+from src.audio2pose_models.res_unet import ResUnet
+
+def class2onehot(idx, class_num):
+
+    assert torch.max(idx).item() < class_num
+    onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
+    onehot.scatter_(1, idx, 1)
+    return onehot
+
+class CVAE(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
+        decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
+        latent_size = cfg.MODEL.CVAE.LATENT_SIZE
+        num_classes = cfg.DATASET.NUM_CLASSES
+        audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
+        audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
+        seq_len = cfg.MODEL.CVAE.SEQ_LEN
+
+        self.latent_size = latent_size
+
+        self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
+                                audio_emb_in_size, audio_emb_out_size, seq_len)
+        self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
+                                audio_emb_in_size, audio_emb_out_size, seq_len)
+    def reparameterize(self, mu, logvar):
+        std = torch.exp(0.5 * logvar)
+        eps = torch.randn_like(std)
+        return mu + eps * std
+
+    def forward(self, batch):
+        batch = self.encoder(batch)
+        mu = batch['mu']
+        logvar = batch['logvar']
+        z = self.reparameterize(mu, logvar)
+        batch['z'] = z
+        return self.decoder(batch)
+
+    def test(self, batch):
+        '''
+        class_id = batch['class']
+        z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
+        batch['z'] = z
+        '''
+        return self.decoder(batch)
+
+class ENCODER(nn.Module):
+    def __init__(self, layer_sizes, latent_size, num_classes, 
+                audio_emb_in_size, audio_emb_out_size, seq_len):
+        super().__init__()
+
+        self.resunet = ResUnet()
+        self.num_classes = num_classes
+        self.seq_len = seq_len
+
+        self.MLP = nn.Sequential()
+        layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
+        for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
+            self.MLP.add_module(
+                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+            self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+
+        self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
+        self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
+        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+    def forward(self, batch):
+        class_id = batch['class']
+        pose_motion_gt = batch['pose_motion_gt']                             #bs seq_len 6
+        ref = batch['ref']                             #bs 6
+        bs = pose_motion_gt.shape[0]
+        audio_in = batch['audio_emb']                          # bs seq_len audio_emb_in_size
+
+        #pose encode
+        pose_emb = self.resunet(pose_motion_gt.unsqueeze(1))          #bs 1 seq_len 6 
+        pose_emb = pose_emb.reshape(bs, -1)                    #bs seq_len*6
+
+        #audio mapping
+        print(audio_in.shape)
+        audio_out = self.linear_audio(audio_in)                # bs seq_len audio_emb_out_size
+        audio_out = audio_out.reshape(bs, -1)
+
+        class_bias = self.classbias[class_id]                  #bs latent_size
+        x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
+        x_out = self.MLP(x_in)
+
+        mu = self.linear_means(x_out)
+        logvar = self.linear_means(x_out)                      #bs latent_size 
+
+        batch.update({'mu':mu, 'logvar':logvar})
+        return batch
+
+class DECODER(nn.Module):
+    def __init__(self, layer_sizes, latent_size, num_classes, 
+                audio_emb_in_size, audio_emb_out_size, seq_len):
+        super().__init__()
+
+        self.resunet = ResUnet()
+        self.num_classes = num_classes
+        self.seq_len = seq_len
+
+        self.MLP = nn.Sequential()
+        input_size = latent_size + seq_len*audio_emb_out_size + 6
+        for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
+            self.MLP.add_module(
+                name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
+            if i+1 < len(layer_sizes):
+                self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
+            else:
+                self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
+        
+        self.pose_linear = nn.Linear(6, 6)
+        self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
+
+        self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
+
+    def forward(self, batch):
+
+        z = batch['z']                                          #bs latent_size
+        bs = z.shape[0]
+        class_id = batch['class']
+        ref = batch['ref']                             #bs 6
+        audio_in = batch['audio_emb']                           # bs seq_len audio_emb_in_size
+        #print('audio_in: ', audio_in[:, :, :10])
+
+        audio_out = self.linear_audio(audio_in)                 # bs seq_len audio_emb_out_size
+        #print('audio_out: ', audio_out[:, :, :10])
+        audio_out = audio_out.reshape([bs, -1])                 # bs seq_len*audio_emb_out_size
+        class_bias = self.classbias[class_id]                   #bs latent_size
+
+        z = z + class_bias
+        x_in = torch.cat([ref, z, audio_out], dim=-1)
+        x_out = self.MLP(x_in)                                  # bs layer_sizes[-1]
+        x_out = x_out.reshape((bs, self.seq_len, -1))
+
+        #print('x_out: ', x_out)
+
+        pose_emb = self.resunet(x_out.unsqueeze(1))             #bs 1 seq_len 6
+
+        pose_motion_pred = self.pose_linear(pose_emb.squeeze(1))       #bs seq_len 6
+
+        batch.update({'pose_motion_pred':pose_motion_pred})
+        return batch
diff --git a/src/audio2pose_models/discriminator.py b/src/audio2pose_models/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f8ed6e36708d4a70227ff90109f56c6f73a17d2
--- /dev/null
+++ b/src/audio2pose_models/discriminator.py
@@ -0,0 +1,76 @@
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+class ConvNormRelu(nn.Module):
+    def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
+                 kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
+        super().__init__()
+        if kernel_size is None:
+            if downsample:
+                kernel_size, stride, padding = 4, 2, 1
+            else:
+                kernel_size, stride, padding = 3, 1, 1
+
+        if conv_type == '2d':
+            self.conv = nn.Conv2d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride,
+                padding,
+                bias=False,
+            )
+            if norm == 'BN':
+                self.norm = nn.BatchNorm2d(out_channels)
+            elif norm == 'IN':
+                self.norm = nn.InstanceNorm2d(out_channels)
+            else:
+                raise NotImplementedError
+        elif conv_type == '1d':
+            self.conv = nn.Conv1d(
+                in_channels,
+                out_channels,
+                kernel_size,
+                stride,
+                padding,
+                bias=False,
+            )
+            if norm == 'BN':
+                self.norm = nn.BatchNorm1d(out_channels)
+            elif norm == 'IN':
+                self.norm = nn.InstanceNorm1d(out_channels)
+            else:
+                raise NotImplementedError
+        nn.init.kaiming_normal_(self.conv.weight)
+
+        self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
+
+    def forward(self, x):
+        x = self.conv(x)
+        if isinstance(self.norm, nn.InstanceNorm1d):
+            x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1))  # normalize on [C]
+        else:
+            x = self.norm(x)
+        x = self.act(x)
+        return x
+
+
+class PoseSequenceDiscriminator(nn.Module):
+    def __init__(self, cfg):
+        super().__init__()
+        self.cfg = cfg
+        leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
+
+        self.seq = nn.Sequential(
+            ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky),  # B, 256, 64
+            ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky),  # B, 512, 32
+            ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky),  # B, 1024, 16
+            nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True)  # B, 1, 16
+        )
+
+    def forward(self, x):
+        x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
+        x = self.seq(x)
+        x = x.squeeze(1)
+        return x
\ No newline at end of file
diff --git a/src/audio2pose_models/networks.py b/src/audio2pose_models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..9212b49836d9221895993d1d490a476707599922
--- /dev/null
+++ b/src/audio2pose_models/networks.py
@@ -0,0 +1,140 @@
+import torch.nn as nn
+import torch
+
+
+class ResidualConv(nn.Module):
+    def __init__(self, input_dim, output_dim, stride, padding):
+        super(ResidualConv, self).__init__()
+
+        self.conv_block = nn.Sequential(
+            nn.BatchNorm2d(input_dim),
+            nn.ReLU(),
+            nn.Conv2d(
+                input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
+            ),
+            nn.BatchNorm2d(output_dim),
+            nn.ReLU(),
+            nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
+        )
+        self.conv_skip = nn.Sequential(
+            nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
+            nn.BatchNorm2d(output_dim),
+        )
+
+    def forward(self, x):
+
+        return self.conv_block(x) + self.conv_skip(x)
+
+
+class Upsample(nn.Module):
+    def __init__(self, input_dim, output_dim, kernel, stride):
+        super(Upsample, self).__init__()
+
+        self.upsample = nn.ConvTranspose2d(
+            input_dim, output_dim, kernel_size=kernel, stride=stride
+        )
+
+    def forward(self, x):
+        return self.upsample(x)
+
+
+class Squeeze_Excite_Block(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(Squeeze_Excite_Block, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+            nn.Linear(channel, channel // reduction, bias=False),
+            nn.ReLU(inplace=True),
+            nn.Linear(channel // reduction, channel, bias=False),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y.expand_as(x)
+
+
+class ASPP(nn.Module):
+    def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
+        super(ASPP, self).__init__()
+
+        self.aspp_block1 = nn.Sequential(
+            nn.Conv2d(
+                in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
+            ),
+            nn.ReLU(inplace=True),
+            nn.BatchNorm2d(out_dims),
+        )
+        self.aspp_block2 = nn.Sequential(
+            nn.Conv2d(
+                in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
+            ),
+            nn.ReLU(inplace=True),
+            nn.BatchNorm2d(out_dims),
+        )
+        self.aspp_block3 = nn.Sequential(
+            nn.Conv2d(
+                in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
+            ),
+            nn.ReLU(inplace=True),
+            nn.BatchNorm2d(out_dims),
+        )
+
+        self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
+        self._init_weights()
+
+    def forward(self, x):
+        x1 = self.aspp_block1(x)
+        x2 = self.aspp_block2(x)
+        x3 = self.aspp_block3(x)
+        out = torch.cat([x1, x2, x3], dim=1)
+        return self.output(out)
+
+    def _init_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight)
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+class Upsample_(nn.Module):
+    def __init__(self, scale=2):
+        super(Upsample_, self).__init__()
+
+        self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
+
+    def forward(self, x):
+        return self.upsample(x)
+
+
+class AttentionBlock(nn.Module):
+    def __init__(self, input_encoder, input_decoder, output_dim):
+        super(AttentionBlock, self).__init__()
+
+        self.conv_encoder = nn.Sequential(
+            nn.BatchNorm2d(input_encoder),
+            nn.ReLU(),
+            nn.Conv2d(input_encoder, output_dim, 3, padding=1),
+            nn.MaxPool2d(2, 2),
+        )
+
+        self.conv_decoder = nn.Sequential(
+            nn.BatchNorm2d(input_decoder),
+            nn.ReLU(),
+            nn.Conv2d(input_decoder, output_dim, 3, padding=1),
+        )
+
+        self.conv_attn = nn.Sequential(
+            nn.BatchNorm2d(output_dim),
+            nn.ReLU(),
+            nn.Conv2d(output_dim, 1, 1),
+        )
+
+    def forward(self, x1, x2):
+        out = self.conv_encoder(x1) + self.conv_decoder(x2)
+        out = self.conv_attn(out)
+        return out * x2
\ No newline at end of file
diff --git a/src/audio2pose_models/res_unet.py b/src/audio2pose_models/res_unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..280404c2a2804038705f792dd800ddf707b75cf8
--- /dev/null
+++ b/src/audio2pose_models/res_unet.py
@@ -0,0 +1,65 @@
+import torch
+import torch.nn as nn
+from src.audio2pose_models.networks import ResidualConv, Upsample
+
+
+class ResUnet(nn.Module):
+    def __init__(self, channel=1, filters=[32, 64, 128, 256]):
+        super(ResUnet, self).__init__()
+
+        self.input_layer = nn.Sequential(
+            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
+            nn.BatchNorm2d(filters[0]),
+            nn.ReLU(),
+            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
+        )
+        self.input_skip = nn.Sequential(
+            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
+        )
+
+        self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
+        self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
+
+        self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
+
+        self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
+        self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
+
+        self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
+        self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
+
+        self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
+        self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
+
+        self.output_layer = nn.Sequential(
+            nn.Conv2d(filters[0], 1, 1, 1),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        # Encode
+        x1 = self.input_layer(x) + self.input_skip(x)
+        x2 = self.residual_conv_1(x1)
+        x3 = self.residual_conv_2(x2)
+        # Bridge
+        x4 = self.bridge(x3)
+
+        # Decode
+        x4 = self.upsample_1(x4)
+        x5 = torch.cat([x4, x3], dim=1)
+
+        x6 = self.up_residual_conv1(x5)
+
+        x6 = self.upsample_2(x6)
+        x7 = torch.cat([x6, x2], dim=1)
+
+        x8 = self.up_residual_conv2(x7)
+
+        x8 = self.upsample_3(x8)
+        x9 = torch.cat([x8, x1], dim=1)
+
+        x10 = self.up_residual_conv3(x9)
+
+        output = self.output_layer(x10)
+
+        return output
\ No newline at end of file
diff --git a/src/config/auido2exp.yaml b/src/config/auido2exp.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7e0e8fbba267158d26a147c8cb2ec5acdd73f432
--- /dev/null
+++ b/src/config/auido2exp.yaml
@@ -0,0 +1,58 @@
+DATASET:
+  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
+  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
+  TRAIN_BATCH_SIZE: 32
+  EVAL_BATCH_SIZE: 32
+  EXP: True
+  EXP_DIM: 64
+  FRAME_LEN: 32
+  COEFF_LEN: 73
+  NUM_CLASSES: 46
+  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
+  LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+  DEBUG: True
+  NUM_REPEATS: 2
+  T: 40
+  
+
+MODEL:
+  FRAMEWORK: V2
+  AUDIOENCODER:
+    LEAKY_RELU: True
+    NORM: 'IN'
+  DISCRIMINATOR:
+    LEAKY_RELU: False
+    INPUT_CHANNELS: 6
+  CVAE:
+    AUDIO_EMB_IN_SIZE: 512
+    AUDIO_EMB_OUT_SIZE: 128
+    SEQ_LEN: 32
+    LATENT_SIZE: 256
+    ENCODER_LAYER_SIZES: [192, 1024]
+    DECODER_LAYER_SIZES: [1024, 192]
+    
+
+TRAIN:
+  MAX_EPOCH: 300
+  GENERATOR:
+    LR: 2.0e-5
+  DISCRIMINATOR:
+    LR: 1.0e-5
+  LOSS:
+    W_FEAT: 0
+    W_COEFF_EXP: 2
+    W_LM: 1.0e-2
+    W_LM_MOUTH: 0
+    W_REG: 0
+    W_SYNC: 0
+    W_COLOR: 0
+    W_EXPRESSION: 0
+    W_LIPREADING: 0.01
+    W_LIPREADING_VV: 0
+    W_EYE_BLINK: 4
+
+TAG:
+  NAME:  small_dataset
+
+
diff --git a/src/config/auido2pose.yaml b/src/config/auido2pose.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..7702414b11581ff99aef7a3187f0d0d1388ae3f3
--- /dev/null
+++ b/src/config/auido2pose.yaml
@@ -0,0 +1,49 @@
+DATASET:
+  TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
+  EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
+  TRAIN_BATCH_SIZE: 64
+  EVAL_BATCH_SIZE: 1
+  EXP: True
+  EXP_DIM: 64
+  FRAME_LEN: 32
+  COEFF_LEN: 73
+  NUM_CLASSES: 46
+  AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
+  COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
+  DEBUG: True
+  
+
+MODEL:
+  AUDIOENCODER:
+    LEAKY_RELU: True
+    NORM: 'IN'
+  DISCRIMINATOR:
+    LEAKY_RELU: False
+    INPUT_CHANNELS: 6
+  CVAE:
+    AUDIO_EMB_IN_SIZE: 512
+    AUDIO_EMB_OUT_SIZE: 6
+    SEQ_LEN: 32
+    LATENT_SIZE: 64
+    ENCODER_LAYER_SIZES: [192, 128]
+    DECODER_LAYER_SIZES: [128, 192]
+    
+
+TRAIN:
+  MAX_EPOCH: 150
+  GENERATOR:
+    LR: 1.0e-4
+  DISCRIMINATOR:
+    LR: 1.0e-4
+  LOSS:
+    LAMBDA_REG: 1
+    LAMBDA_LANDMARKS: 0
+    LAMBDA_VERTICES: 0
+    LAMBDA_GAN_MOTION: 0.7
+    LAMBDA_GAN_COEFF: 0
+    LAMBDA_KL: 1
+
+TAG:
+  NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
+
+
diff --git a/src/config/facerender.yaml b/src/config/facerender.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..dd1e1ddfe265698e49dac4a6e103cba0aac4f3ce
--- /dev/null
+++ b/src/config/facerender.yaml
@@ -0,0 +1,45 @@
+model_params:
+  common_params:
+    num_kp: 15 
+    image_channel: 3                    
+    feature_channel: 32
+    estimate_jacobian: False   # True
+  kp_detector_params:
+     temperature: 0.1
+     block_expansion: 32            
+     max_features: 1024
+     scale_factor: 0.25         # 0.25
+     num_blocks: 5
+     reshape_channel: 16384  # 16384 = 1024 * 16
+     reshape_depth: 16
+  he_estimator_params:
+     block_expansion: 64            
+     max_features: 2048
+     num_bins: 66
+  generator_params:
+    block_expansion: 64
+    max_features: 512
+    num_down_blocks: 2
+    reshape_channel: 32
+    reshape_depth: 16         # 512 = 32 * 16
+    num_resblocks: 6
+    estimate_occlusion_map: True
+    dense_motion_params:
+      block_expansion: 32
+      max_features: 1024
+      num_blocks: 5
+      reshape_depth: 16
+      compress: 4
+  discriminator_params:
+    scales: [1]
+    block_expansion: 32                 
+    max_features: 512
+    num_blocks: 4
+    sn: True
+  mapping_params:
+      coeff_nc: 70
+      descriptor_nc: 1024
+      layer: 3
+      num_kp: 15
+      num_bins: 66
+
diff --git a/src/config/facerender_still.yaml b/src/config/facerender_still.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..d6b84181763caf7184a0769e53a7e419e2e3f604
--- /dev/null
+++ b/src/config/facerender_still.yaml
@@ -0,0 +1,45 @@
+model_params:
+  common_params:
+    num_kp: 15 
+    image_channel: 3                    
+    feature_channel: 32
+    estimate_jacobian: False   # True
+  kp_detector_params:
+     temperature: 0.1
+     block_expansion: 32            
+     max_features: 1024
+     scale_factor: 0.25         # 0.25
+     num_blocks: 5
+     reshape_channel: 16384  # 16384 = 1024 * 16
+     reshape_depth: 16
+  he_estimator_params:
+     block_expansion: 64            
+     max_features: 2048
+     num_bins: 66
+  generator_params:
+    block_expansion: 64
+    max_features: 512
+    num_down_blocks: 2
+    reshape_channel: 32
+    reshape_depth: 16         # 512 = 32 * 16
+    num_resblocks: 6
+    estimate_occlusion_map: True
+    dense_motion_params:
+      block_expansion: 32
+      max_features: 1024
+      num_blocks: 5
+      reshape_depth: 16
+      compress: 4
+  discriminator_params:
+    scales: [1]
+    block_expansion: 32                 
+    max_features: 512
+    num_blocks: 4
+    sn: True
+  mapping_params:
+      coeff_nc: 73
+      descriptor_nc: 1024
+      layer: 3
+      num_kp: 15
+      num_bins: 66
+
diff --git a/src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc b/src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b907c972bbd7e781cfbd61cabdcb076756d887a3
Binary files /dev/null and b/src/face3d/__pycache__/extract_kp_videos.cpython-38.pyc differ
diff --git a/src/face3d/__pycache__/extract_kp_videos.cpython-39.pyc b/src/face3d/__pycache__/extract_kp_videos.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d13ad86aa30bb74014be5fc229183355b9c2e97
Binary files /dev/null and b/src/face3d/__pycache__/extract_kp_videos.cpython-39.pyc differ
diff --git a/src/face3d/data/__init__.py b/src/face3d/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..be2378c5877af8e749db18d8a67a382f3eb0912b
--- /dev/null
+++ b/src/face3d/data/__init__.py
@@ -0,0 +1,116 @@
+"""This package includes all the modules related to data loading and preprocessing
+
+ To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
+ You need to implement four functions:
+    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
+    -- <__len__>:                       return the size of dataset.
+    -- <__getitem__>:                   get a data point from data loader.
+    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
+
+Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
+See our template dataset class 'template_dataset.py' for more details.
+"""
+import numpy as np
+import importlib
+import torch.utils.data
+from face3d.data.base_dataset import BaseDataset
+
+
+def find_dataset_using_name(dataset_name):
+    """Import the module "data/[dataset_name]_dataset.py".
+
+    In the file, the class called DatasetNameDataset() will
+    be instantiated. It has to be a subclass of BaseDataset,
+    and it is case-insensitive.
+    """
+    dataset_filename = "data." + dataset_name + "_dataset"
+    datasetlib = importlib.import_module(dataset_filename)
+
+    dataset = None
+    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
+    for name, cls in datasetlib.__dict__.items():
+        if name.lower() == target_dataset_name.lower() \
+           and issubclass(cls, BaseDataset):
+            dataset = cls
+
+    if dataset is None:
+        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))
+
+    return dataset
+
+
+def get_option_setter(dataset_name):
+    """Return the static method <modify_commandline_options> of the dataset class."""
+    dataset_class = find_dataset_using_name(dataset_name)
+    return dataset_class.modify_commandline_options
+
+
+def create_dataset(opt, rank=0):
+    """Create a dataset given the option.
+
+    This function wraps the class CustomDatasetDataLoader.
+        This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from data import create_dataset
+        >>> dataset = create_dataset(opt)
+    """
+    data_loader = CustomDatasetDataLoader(opt, rank=rank)
+    dataset = data_loader.load_data()
+    return dataset
+
+class CustomDatasetDataLoader():
+    """Wrapper class of Dataset class that performs multi-threaded data loading"""
+
+    def __init__(self, opt, rank=0):
+        """Initialize this class
+
+        Step 1: create a dataset instance given the name [dataset_mode]
+        Step 2: create a multi-threaded data loader.
+        """
+        self.opt = opt
+        dataset_class = find_dataset_using_name(opt.dataset_mode)
+        self.dataset = dataset_class(opt)
+        self.sampler = None
+        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
+        if opt.use_ddp and opt.isTrain:
+            world_size = opt.world_size
+            self.sampler = torch.utils.data.distributed.DistributedSampler(
+                    self.dataset,
+                    num_replicas=world_size,
+                    rank=rank,
+                    shuffle=not opt.serial_batches
+                )
+            self.dataloader = torch.utils.data.DataLoader(
+                        self.dataset,
+                        sampler=self.sampler,
+                        num_workers=int(opt.num_threads / world_size), 
+                        batch_size=int(opt.batch_size / world_size), 
+                        drop_last=True)
+        else:
+            self.dataloader = torch.utils.data.DataLoader(
+                self.dataset,
+                batch_size=opt.batch_size,
+                shuffle=(not opt.serial_batches) and opt.isTrain,
+                num_workers=int(opt.num_threads),
+                drop_last=True
+            )
+
+    def set_epoch(self, epoch):
+        self.dataset.current_epoch = epoch
+        if self.sampler is not None:
+            self.sampler.set_epoch(epoch)
+
+    def load_data(self):
+        return self
+
+    def __len__(self):
+        """Return the number of data in the dataset"""
+        return min(len(self.dataset), self.opt.max_dataset_size)
+
+    def __iter__(self):
+        """Return a batch of data"""
+        for i, data in enumerate(self.dataloader):
+            if i * self.opt.batch_size >= self.opt.max_dataset_size:
+                break
+            yield data
diff --git a/src/face3d/data/base_dataset.py b/src/face3d/data/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..34a7ea5024206e6e58c2f404ac6a1bf0987f5fd4
--- /dev/null
+++ b/src/face3d/data/base_dataset.py
@@ -0,0 +1,125 @@
+"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.
+
+It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
+"""
+import random
+import numpy as np
+import torch.utils.data as data
+from PIL import Image
+import torchvision.transforms as transforms
+from abc import ABC, abstractmethod
+
+
+class BaseDataset(data.Dataset, ABC):
+    """This class is an abstract base class (ABC) for datasets.
+
+    To create a subclass, you need to implement the following four functions:
+    -- <__init__>:                      initialize the class, first call BaseDataset.__init__(self, opt).
+    -- <__len__>:                       return the size of dataset.
+    -- <__getitem__>:                   get a data point.
+    -- <modify_commandline_options>:    (optionally) add dataset-specific options and set default options.
+    """
+
+    def __init__(self, opt):
+        """Initialize the class; save the options in the class
+
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        self.opt = opt
+        # self.root = opt.dataroot
+        self.current_epoch = 0
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new dataset-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def __len__(self):
+        """Return the total number of images in the dataset."""
+        return 0
+
+    @abstractmethod
+    def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index - - a random integer for data indexing
+
+        Returns:
+            a dictionary of data with their names. It ususally contains the data itself and its metadata information.
+        """
+        pass
+
+
+def get_transform(grayscale=False):
+    transform_list = []
+    if grayscale:
+        transform_list.append(transforms.Grayscale(1))
+    transform_list += [transforms.ToTensor()]
+    return transforms.Compose(transform_list)
+
+def get_affine_mat(opt, size):
+    shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
+    w, h = size
+
+    if 'shift' in opt.preprocess:
+        shift_pixs = int(opt.shift_pixs)
+        shift_x = random.randint(-shift_pixs, shift_pixs)
+        shift_y = random.randint(-shift_pixs, shift_pixs)
+    if 'scale' in opt.preprocess:
+        scale = 1 + opt.scale_delta * (2 * random.random() - 1)
+    if 'rot' in opt.preprocess:
+        rot_angle = opt.rot_angle * (2 * random.random() - 1)
+        rot_rad = -rot_angle * np.pi/180
+    if 'flip' in opt.preprocess:
+        flip = random.random() > 0.5
+
+    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
+    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
+    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
+    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
+    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
+    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])
+    
+    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin    
+    affine_inv = np.linalg.inv(affine)
+    return affine, affine_inv, flip
+
+def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
+    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)
+
+def apply_lm_affine(landmark, affine, flip, size):
+    _, h = size
+    lm = landmark.copy()
+    lm[:, 1] = h - 1 - lm[:, 1]
+    lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
+    lm = lm @ np.transpose(affine)
+    lm[:, :2] = lm[:, :2] / lm[:, 2:]
+    lm = lm[:, :2]
+    lm[:, 1] = h - 1 - lm[:, 1]
+    if flip:
+        lm_ = lm.copy()
+        lm_[:17] = lm[16::-1]
+        lm_[17:22] = lm[26:21:-1]
+        lm_[22:27] = lm[21:16:-1]
+        lm_[31:36] = lm[35:30:-1]
+        lm_[36:40] = lm[45:41:-1]
+        lm_[40:42] = lm[47:45:-1]
+        lm_[42:46] = lm[39:35:-1]
+        lm_[46:48] = lm[41:39:-1]
+        lm_[48:55] = lm[54:47:-1]
+        lm_[55:60] = lm[59:54:-1]
+        lm_[60:65] = lm[64:59:-1]
+        lm_[65:68] = lm[67:64:-1]
+        lm = lm_
+    return lm
diff --git a/src/face3d/data/flist_dataset.py b/src/face3d/data/flist_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..63b49caa8020f8e9aedb73a839b7112320cad68a
--- /dev/null
+++ b/src/face3d/data/flist_dataset.py
@@ -0,0 +1,125 @@
+"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
+"""
+
+import os.path
+from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
+from data.image_folder import make_dataset
+from PIL import Image
+import random
+import util.util as util
+import numpy as np
+import json
+import torch
+from scipy.io import loadmat, savemat
+import pickle
+from util.preprocess import align_img, estimate_norm
+from util.load_mats import load_lm3d
+
+
+def default_flist_reader(flist):
+    """
+    flist format: impath label\nimpath label\n ...(same to caffe's filelist)
+    """
+    imlist = []
+    with open(flist, 'r') as rf:
+        for line in rf.readlines():
+            impath = line.strip()
+            imlist.append(impath)
+
+    return imlist
+
+def jason_flist_reader(flist):
+    with open(flist, 'r') as fp:
+        info = json.load(fp)
+    return info
+
+def parse_label(label):
+    return torch.tensor(np.array(label).astype(np.float32))
+
+
+class FlistDataset(BaseDataset):
+    """
+    It requires one directories to host training images '/path/to/data/train'
+    You can train the model with the dataset flag '--dataroot /path/to/data'.
+    """
+
+    def __init__(self, opt):
+        """Initialize this dataset class.
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        BaseDataset.__init__(self, opt)
+        
+        self.lm3d_std = load_lm3d(opt.bfm_folder)
+        
+        msk_names = default_flist_reader(opt.flist)
+        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]
+
+        self.size = len(self.msk_paths) 
+        self.opt = opt
+        
+        self.name = 'train' if opt.isTrain else 'val'
+        if '_' in opt.flist:
+            self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]
+        
+
+    def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index (int)      -- a random integer for data indexing
+
+        Returns a dictionary that contains A, B, A_paths and B_paths
+            img (tensor)       -- an image in the input domain
+            msk (tensor)       -- its corresponding attention mask
+            lm  (tensor)       -- its corresponding 3d landmarks
+            im_paths (str)     -- image paths
+            aug_flag (bool)    -- a flag used to tell whether its raw or augmented
+        """
+        msk_path = self.msk_paths[index % self.size]  # make sure index is within then range
+        img_path = msk_path.replace('mask/', '')
+        lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'
+
+        raw_img = Image.open(img_path).convert('RGB')
+        raw_msk = Image.open(msk_path).convert('RGB')
+        raw_lm = np.loadtxt(lm_path).astype(np.float32)
+
+        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)
+        
+        aug_flag = self.opt.use_aug and self.opt.isTrain
+        if aug_flag:
+            img, lm, msk = self._augmentation(img, lm, self.opt, msk)
+        
+        _, H = img.size
+        M = estimate_norm(lm, H)
+        transform = get_transform()
+        img_tensor = transform(img)
+        msk_tensor = transform(msk)[:1, ...]
+        lm_tensor = parse_label(lm)
+        M_tensor = parse_label(M)
+
+
+        return {'imgs': img_tensor, 
+                'lms': lm_tensor, 
+                'msks': msk_tensor, 
+                'M': M_tensor,
+                'im_paths': img_path, 
+                'aug_flag': aug_flag,
+                'dataset': self.name}
+
+    def _augmentation(self, img, lm, opt, msk=None):
+        affine, affine_inv, flip = get_affine_mat(opt, img.size)
+        img = apply_img_affine(img, affine_inv)
+        lm = apply_lm_affine(lm, affine, flip, img.size)
+        if msk is not None:
+            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
+        return img, lm, msk
+    
+
+
+
+    def __len__(self):
+        """Return the total number of images in the dataset.
+        """
+        return self.size
diff --git a/src/face3d/data/image_folder.py b/src/face3d/data/image_folder.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ef069029b0db1fc40b9b5f9a6f52a48c1cd162
--- /dev/null
+++ b/src/face3d/data/image_folder.py
@@ -0,0 +1,66 @@
+"""A modified image folder class
+
+We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
+so that this class can load images from both current directory and its subdirectories.
+"""
+import numpy as np
+import torch.utils.data as data
+
+from PIL import Image
+import os
+import os.path
+
+IMG_EXTENSIONS = [
+    '.jpg', '.JPG', '.jpeg', '.JPEG',
+    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
+    '.tif', '.TIF', '.tiff', '.TIFF',
+]
+
+
+def is_image_file(filename):
+    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
+
+
+def make_dataset(dir, max_dataset_size=float("inf")):
+    images = []
+    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir
+
+    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
+        for fname in fnames:
+            if is_image_file(fname):
+                path = os.path.join(root, fname)
+                images.append(path)
+    return images[:min(max_dataset_size, len(images))]
+
+
+def default_loader(path):
+    return Image.open(path).convert('RGB')
+
+
+class ImageFolder(data.Dataset):
+
+    def __init__(self, root, transform=None, return_paths=False,
+                 loader=default_loader):
+        imgs = make_dataset(root)
+        if len(imgs) == 0:
+            raise(RuntimeError("Found 0 images in: " + root + "\n"
+                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))
+
+        self.root = root
+        self.imgs = imgs
+        self.transform = transform
+        self.return_paths = return_paths
+        self.loader = loader
+
+    def __getitem__(self, index):
+        path = self.imgs[index]
+        img = self.loader(path)
+        if self.transform is not None:
+            img = self.transform(img)
+        if self.return_paths:
+            return img, path
+        else:
+            return img
+
+    def __len__(self):
+        return len(self.imgs)
diff --git a/src/face3d/data/template_dataset.py b/src/face3d/data/template_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..693b6b09085ad424e53f26e0938b61eea30ed644
--- /dev/null
+++ b/src/face3d/data/template_dataset.py
@@ -0,0 +1,75 @@
+"""Dataset class template
+
+This module provides a template for users to implement custom datasets.
+You can specify '--dataset_mode template' to use this dataset.
+The class name should be consistent with both the filename and its dataset_mode option.
+The filename should be <dataset_mode>_dataset.py
+The class name should be <Dataset_mode>Dataset.py
+You need to implement the following functions:
+    -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
+    -- <__init__>: Initialize this dataset class.
+    -- <__getitem__>: Return a data point and its metadata information.
+    -- <__len__>: Return the number of images.
+"""
+from data.base_dataset import BaseDataset, get_transform
+# from data.image_folder import make_dataset
+# from PIL import Image
+
+
+class TemplateDataset(BaseDataset):
+    """A template dataset class for you to implement custom datasets."""
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new dataset-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
+        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
+        return parser
+
+    def __init__(self, opt):
+        """Initialize this dataset class.
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+        A few things can be done here.
+        - save the options (have been done in BaseDataset)
+        - get image paths and meta information of the dataset.
+        - define the image transformation.
+        """
+        # save the option and dataset root
+        BaseDataset.__init__(self, opt)
+        # get the image paths of your dataset;
+        self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
+        # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
+        self.transform = get_transform(opt)
+
+    def __getitem__(self, index):
+        """Return a data point and its metadata information.
+
+        Parameters:
+            index -- a random integer for data indexing
+
+        Returns:
+            a dictionary of data with their names. It usually contains the data itself and its metadata information.
+
+        Step 1: get a random image path: e.g., path = self.image_paths[index]
+        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
+        Step 3: convert your data to a PyTorch tensor. You can use helpder functions such as self.transform. e.g., data = self.transform(image)
+        Step 4: return a data point as a dictionary.
+        """
+        path = 'temp'    # needs to be a string
+        data_A = None    # needs to be a tensor
+        data_B = None    # needs to be a tensor
+        return {'data_A': data_A, 'data_B': data_B, 'path': path}
+
+    def __len__(self):
+        """Return the total number of images."""
+        return len(self.image_paths)
diff --git a/src/face3d/extract_kp_videos.py b/src/face3d/extract_kp_videos.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12b55b877e7ee3ef6491200b4b9aced58260924
--- /dev/null
+++ b/src/face3d/extract_kp_videos.py
@@ -0,0 +1,108 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import face_alignment
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from itertools import cycle
+
+from torch.multiprocessing import Pool, Process, set_start_method
+
+class KeypointExtractor():
+    def __init__(self, device):
+        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType.TWO_D, 
+                                                     device=device)   
+
+    def extract_keypoint(self, images, name=None, info=True):
+        if isinstance(images, list):
+            keypoints = []
+            if info:
+                i_range = tqdm(images,desc='landmark Det:')
+            else:
+                i_range = images
+
+            for image in i_range:
+                current_kp = self.extract_keypoint(image)
+                if np.mean(current_kp) == -1 and keypoints:
+                    keypoints.append(keypoints[-1])
+                else:
+                    keypoints.append(current_kp[None])
+
+            keypoints = np.concatenate(keypoints, 0)
+            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+            return keypoints
+        else:
+            while True:
+                try:
+                    keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
+                    break
+                except RuntimeError as e:
+                    if str(e).startswith('CUDA'):
+                        print("Warning: out of memory, sleep for 1s")
+                        time.sleep(1)
+                    else:
+                        print(e)
+                        break    
+                except TypeError:
+                    print('No face detected in this image')
+                    shape = [68, 2]
+                    keypoints = -1. * np.ones(shape)                    
+                    break
+            if name is not None:
+                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+            return keypoints
+
+def read_video(filename):
+    frames = []
+    cap = cv2.VideoCapture(filename)
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if ret:
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame = Image.fromarray(frame)
+            frames.append(frame)
+        else:
+            break
+    cap.release()
+    return frames
+
+def run(data):
+    filename, opt, device = data
+    os.environ['CUDA_VISIBLE_DEVICES'] = device
+    kp_extractor = KeypointExtractor()
+    images = read_video(filename)
+    name = filename.split('/')[-2:]
+    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+    kp_extractor.extract_keypoint(
+        images, 
+        name=os.path.join(opt.output_dir, name[-2], name[-1])
+    )
+
+if __name__ == '__main__':
+    set_start_method('spawn')
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+    parser.add_argument('--device_ids', type=str, default='0,1')
+    parser.add_argument('--workers', type=int, default=4)
+
+    opt = parser.parse_args()
+    filenames = list()
+    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+    extensions = VIDEO_EXTENSIONS
+    
+    for ext in extensions:
+        os.listdir(f'{opt.input_dir}')
+        print(f'{opt.input_dir}/*.{ext}')
+        filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
+    print('Total number of videos:', len(filenames))
+    pool = Pool(opt.workers)
+    args_list = cycle([opt])
+    device_ids = opt.device_ids.split(",")
+    device_ids = cycle(device_ids)
+    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+        None
diff --git a/src/face3d/extract_kp_videos_safe.py b/src/face3d/extract_kp_videos_safe.py
new file mode 100644
index 0000000000000000000000000000000000000000..2619be14d9d4fa1b6c35062d1038ee2f49667c4f
--- /dev/null
+++ b/src/face3d/extract_kp_videos_safe.py
@@ -0,0 +1,138 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import numpy as np
+from PIL import Image
+import torch
+from tqdm import tqdm
+from itertools import cycle
+from facexlib.alignment import init_alignment_model, landmark_98_to_68
+from facexlib.detection import init_detection_model
+from torch.multiprocessing import Pool, Process, set_start_method
+
+
+class KeypointExtractor():
+    def __init__(self, device='cuda'):
+
+        ### gfpgan/weights
+        try:
+            import webui  # in webui
+            root_path = 'extensions/SadTalker/gfpgan/weights' 
+
+        except:
+            root_path = 'gfpgan/weights'
+
+        self.detector = init_alignment_model('awing_fan',device=device, model_rootpath=root_path)   
+        self.det_net = init_detection_model('retinaface_resnet50', half=False,device=device, model_rootpath=root_path)
+
+    def extract_keypoint(self, images, name=None, info=True):
+        if isinstance(images, list):
+            keypoints = []
+            if info:
+                i_range = tqdm(images,desc='landmark Det:')
+            else:
+                i_range = images
+
+            for image in i_range:
+                current_kp = self.extract_keypoint(image)
+                # current_kp = self.detector.get_landmarks(np.array(image))
+                if np.mean(current_kp) == -1 and keypoints:
+                    keypoints.append(keypoints[-1])
+                else:
+                    keypoints.append(current_kp[None])
+
+            keypoints = np.concatenate(keypoints, 0)
+            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+            return keypoints
+        else:
+            while True:
+                try:
+                    with torch.no_grad():
+                        # face detection -> face alignment.
+                        img = np.array(images)
+                        bboxes = self.det_net.detect_faces(images, 0.97)
+                        
+                        bboxes = bboxes[0]
+
+                        # bboxes[0] -= 100
+                        # bboxes[1] -= 100
+                        # bboxes[2] += 100
+                        # bboxes[3] += 100
+                        img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]
+
+                        keypoints = landmark_98_to_68(self.detector.get_landmarks(img)) # [0]
+
+                        #### keypoints to the original location
+                        keypoints[:,0] += int(bboxes[0])
+                        keypoints[:,1] += int(bboxes[1])
+
+                        break
+                except RuntimeError as e:
+                    if str(e).startswith('CUDA'):
+                        print("Warning: out of memory, sleep for 1s")
+                        time.sleep(1)
+                    else:
+                        print(e)
+                        break    
+                except TypeError:
+                    print('No face detected in this image')
+                    shape = [68, 2]
+                    keypoints = -1. * np.ones(shape)                    
+                    break
+            if name is not None:
+                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
+            return keypoints
+
+def read_video(filename):
+    frames = []
+    cap = cv2.VideoCapture(filename)
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if ret:
+            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
+            frame = Image.fromarray(frame)
+            frames.append(frame)
+        else:
+            break
+    cap.release()
+    return frames
+
+def run(data):
+    filename, opt, device = data
+    os.environ['CUDA_VISIBLE_DEVICES'] = device
+    kp_extractor = KeypointExtractor()
+    images = read_video(filename)
+    name = filename.split('/')[-2:]
+    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
+    kp_extractor.extract_keypoint(
+        images, 
+        name=os.path.join(opt.output_dir, name[-2], name[-1])
+    )
+
+if __name__ == '__main__':
+    set_start_method('spawn')
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+    parser.add_argument('--device_ids', type=str, default='0,1')
+    parser.add_argument('--workers', type=int, default=4)
+
+    opt = parser.parse_args()
+    filenames = list()
+    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+    extensions = VIDEO_EXTENSIONS
+    
+    for ext in extensions:
+        os.listdir(f'{opt.input_dir}')
+        print(f'{opt.input_dir}/*.{ext}')
+        filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
+    print('Total number of videos:', len(filenames))
+    pool = Pool(opt.workers)
+    args_list = cycle([opt])
+    device_ids = opt.device_ids.split(",")
+    device_ids = cycle(device_ids)
+    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+        None
diff --git a/src/face3d/models/__init__.py b/src/face3d/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef6b5e399254bd42850f3385878f35d4acf90852
--- /dev/null
+++ b/src/face3d/models/__init__.py
@@ -0,0 +1,67 @@
+"""This package contains modules related to objective functions, optimizations, and network architectures.
+
+To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
+You need to implement the following five functions:
+    -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
+    -- <set_input>:                     unpack data from dataset and apply preprocessing.
+    -- <forward>:                       produce intermediate results.
+    -- <optimize_parameters>:           calculate loss, gradients, and update network weights.
+    -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
+
+In the function <__init__>, you need to define four lists:
+    -- self.loss_names (str list):          specify the training losses that you want to plot and save.
+    -- self.model_names (str list):         define networks used in our training.
+    -- self.visual_names (str list):        specify the images that you want to display and save.
+    -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an usage.
+
+Now you can use the model class by specifying flag '--model dummy'.
+See our template model class 'template_model.py' for more details.
+"""
+
+import importlib
+from src.face3d.models.base_model import BaseModel
+
+
+def find_model_using_name(model_name):
+    """Import the module "models/[model_name]_model.py".
+
+    In the file, the class called DatasetNameModel() will
+    be instantiated. It has to be a subclass of BaseModel,
+    and it is case-insensitive.
+    """
+    model_filename = "face3d.models." + model_name + "_model"
+    modellib = importlib.import_module(model_filename)
+    model = None
+    target_model_name = model_name.replace('_', '') + 'model'
+    for name, cls in modellib.__dict__.items():
+        if name.lower() == target_model_name.lower() \
+           and issubclass(cls, BaseModel):
+            model = cls
+
+    if model is None:
+        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
+        exit(0)
+
+    return model
+
+
+def get_option_setter(model_name):
+    """Return the static method <modify_commandline_options> of the model class."""
+    model_class = find_model_using_name(model_name)
+    return model_class.modify_commandline_options
+
+
+def create_model(opt):
+    """Create a model given the option.
+
+    This function warps the class CustomDatasetDataLoader.
+    This is the main interface between this package and 'train.py'/'test.py'
+
+    Example:
+        >>> from models import create_model
+        >>> model = create_model(opt)
+    """
+    model = find_model_using_name(opt.model)
+    instance = model(opt)
+    print("model [%s] was created" % type(instance).__name__)
+    return instance
diff --git a/src/face3d/models/__pycache__/__init__.cpython-38.pyc b/src/face3d/models/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..1a32e76ab0d233478046b3184acb3e9c9acc2728
Binary files /dev/null and b/src/face3d/models/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/face3d/models/__pycache__/__init__.cpython-39.pyc b/src/face3d/models/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f82c186fe75aee5051e10cae8154f19919a6ea35
Binary files /dev/null and b/src/face3d/models/__pycache__/__init__.cpython-39.pyc differ
diff --git a/src/face3d/models/__pycache__/base_model.cpython-38.pyc b/src/face3d/models/__pycache__/base_model.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d9921f21b1365b1243d0032815ae21438d765ade
Binary files /dev/null and b/src/face3d/models/__pycache__/base_model.cpython-38.pyc differ
diff --git a/src/face3d/models/__pycache__/base_model.cpython-39.pyc b/src/face3d/models/__pycache__/base_model.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..67c46a9ef196770c70cbcf18f3287f31d4a0f86a
Binary files /dev/null and b/src/face3d/models/__pycache__/base_model.cpython-39.pyc differ
diff --git a/src/face3d/models/__pycache__/networks.cpython-38.pyc b/src/face3d/models/__pycache__/networks.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d86eaae39301dac5e9ea110e7f0f0122ce549eb0
Binary files /dev/null and b/src/face3d/models/__pycache__/networks.cpython-38.pyc differ
diff --git a/src/face3d/models/__pycache__/networks.cpython-39.pyc b/src/face3d/models/__pycache__/networks.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..86d85cab2e8fc79e9ce4cbbff7587a57c101cba0
Binary files /dev/null and b/src/face3d/models/__pycache__/networks.cpython-39.pyc differ
diff --git a/src/face3d/models/arcface_torch/README.md b/src/face3d/models/arcface_torch/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cc7f1d45f2f5e4b752c42dc81d3e2879c1459c6e
--- /dev/null
+++ b/src/face3d/models/arcface_torch/README.md
@@ -0,0 +1,164 @@
+# Distributed Arcface Training in Pytorch
+
+This is a deep learning library that makes face recognition efficient, and effective, which can train tens of millions
+identity on a single server.
+
+## Requirements
+
+- Install [pytorch](http://pytorch.org) (torch>=1.6.0), our doc for [install.md](docs/install.md).
+- `pip install -r requirements.txt`.
+- Download the dataset
+  from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_)
+  .
+
+## How to Training
+
+To train a model, run `train.py` with the path to the configs:
+
+### 1. Single node, 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+```
+
+### 2. Multiple nodes, each node 8 GPUs:
+
+Node 0:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
+```
+
+Node 1:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py train.py configs/ms1mv3_r50
+```
+
+### 3.Training resnet2060 with 8 GPUs:
+
+```shell
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
+```
+
+## Model Zoo
+
+- The models are available for non-commercial research purposes only.  
+- All models can be found in here.  
+- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g):   e8pw  
+- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)
+
+### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)
+
+ICCV2021-MFR testset consists of non-celebrities so we can ensure that it has very few overlap with public available face 
+recognition training set, such as MS1M and CASIA as they mostly collected from online celebrities. 
+As the result, we can evaluate the FAIR performance for different algorithms.  
+
+For **ICCV2021-MFR-ALL** set, TAR is measured on all-to-all 1:1 protocal, with FAR less than 0.000001(e-6). The 
+globalised multi-racial testset contains 242,143 identities and 1,624,305 images. 
+
+For **ICCV2021-MFR-MASK** set, TAR is measured on mask-to-nonmask 1:1 protocal, with FAR less than 0.0001(e-4). 
+Mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images. 
+There are totally 13,928 positive pairs and 96,983,824 negative pairs.
+
+| Datasets | backbone  | Training throughout | Size / MB  | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
+| :---:    | :---      | :---                | :---       |:---                   |:---                  |     
+| MS1MV3    | r18  | -              | 91   | **47.85** | **68.33** |
+| Glint360k | r18  | 8536           | 91   | **53.32** | **72.07** |
+| MS1MV3    | r34  | -              | 130  | **58.72** | **77.36** |
+| Glint360k | r34  | 6344           | 130  | **65.10** | **83.02** |
+| MS1MV3    | r50  | 5500           | 166  | **63.85** | **80.53** |
+| Glint360k | r50  | 5136           | 166  | **70.23** | **87.08** |
+| MS1MV3    | r100 | -              | 248  | **69.09** | **84.31** |
+| Glint360k | r100 | 3332           | 248  | **75.57** | **90.66** |
+| MS1MV3    | mobilefacenet | 12185 | 7.8  | **41.52** | **65.26** |        
+| Glint360k | mobilefacenet | 11197 | 7.8  | **44.52** | **66.48** |  
+
+### Performance on IJB-C and Verification Datasets
+
+|   Datasets | backbone      | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw  |  log    |
+| :---:      |    :---       | :---          | :---  | :---  |:---   |:---    |:---     |  
+| MS1MV3     | r18      | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|         
+| MS1MV3     | r34      | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|        
+| MS1MV3     | r50      | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|         
+| MS1MV3     | r100     | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|        
+| MS1MV3     | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
+| Glint360k  |r18-0.1   | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)| 
+| Glint360k  |r34-0.1   | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)| 
+| Glint360k  |r50-0.1   | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)| 
+| Glint360k  |r100-0.1  | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|
+
+[comment]: <> (More details see [model.md]&#40;docs/modelzoo.md&#41; in docs.)
+
+
+## [Speed Benchmark](docs/speed_benchmark.md)
+
+**Arcface Torch** can train large-scale face recognition training set efficiently and quickly. When the number of
+classes in training sets is greater than 300K and the training is sufficient, partial fc sampling strategy will get same
+accuracy with several times faster training performance and smaller GPU memory. 
+Partial FC is a sparse variant of the model parallel architecture for large sacle  face recognition. Partial FC use a 
+sparse softmax, where each batch dynamicly sample a subset of class centers for training. In each iteration, only a 
+sparse part of the parameters will be updated, which can reduce a lot of GPU memory and calculations. With Partial FC, 
+we can scale trainset of 29 millions identities, the largest to date. Partial FC also supports multi-machine distributed 
+training and mixed precision training.
+
+![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)
+
+More details see 
+[speed_benchmark.md](docs/speed_benchmark.md) in docs.
+
+### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)
+
+`-` means training failed because of gpu memory limitations.
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :---    | :--- | :--- | :--- |
+|125000   | 4681         | 4824          | 5004     |
+|1400000  | **1672**     | 3043          | 4738     |
+|5500000  | **-**        | **1389**      | 3975     |
+|8000000  | **-**        | **-**         | 3565     |
+|16000000 | **-**        | **-**         | 2679     |
+|29000000 | **-**        | **-**         | **1855** |
+
+### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :---    | :---      | :---      | :---  |
+|125000   | 7358      | 5306      | 4868  |
+|1400000  | 32252     | 11178     | 6056  |
+|5500000  | **-**     | 32188     | 9854  |
+|8000000  | **-**     | **-**     | 12310 |
+|16000000 | **-**     | **-**     | 19950 |
+|29000000 | **-**     | **-**     | 32324 |
+
+## Evaluation ICCV2021-MFR and IJB-C
+
+More details see [eval.md](docs/eval.md) in docs.
+
+## Test
+
+We tested many versions of PyTorch. Please create an issue if you are having trouble.  
+
+- [x] torch 1.6.0
+- [x] torch 1.7.1
+- [x] torch 1.8.0
+- [x] torch 1.9.0
+
+## Citation
+
+```
+@inproceedings{deng2019arcface,
+  title={Arcface: Additive angular margin loss for deep face recognition},
+  author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
+  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
+  pages={4690--4699},
+  year={2019}
+}
+@inproceedings{an2020partical_fc,
+  title={Partial FC: Training 10 Million Identities on a Single Machine},
+  author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
+  Zhang, Debing and Fu Ying},
+  booktitle={Arxiv 2010.05222},
+  year={2020}
+}
+```
diff --git a/src/face3d/models/arcface_torch/backbones/__init__.py b/src/face3d/models/arcface_torch/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5650187b4fdea84c5a23e0445440901690ab682a
--- /dev/null
+++ b/src/face3d/models/arcface_torch/backbones/__init__.py
@@ -0,0 +1,25 @@
+from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
+from .mobilefacenet import get_mbf
+
+
+def get_model(name, **kwargs):
+    # resnet
+    if name == "r18":
+        return iresnet18(False, **kwargs)
+    elif name == "r34":
+        return iresnet34(False, **kwargs)
+    elif name == "r50":
+        return iresnet50(False, **kwargs)
+    elif name == "r100":
+        return iresnet100(False, **kwargs)
+    elif name == "r200":
+        return iresnet200(False, **kwargs)
+    elif name == "r2060":
+        from .iresnet2060 import iresnet2060
+        return iresnet2060(False, **kwargs)
+    elif name == "mbf":
+        fp16 = kwargs.get("fp16", False)
+        num_features = kwargs.get("num_features", 512)
+        return get_mbf(fp16=fp16, num_features=num_features)
+    else:
+        raise ValueError()
\ No newline at end of file
diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d0d8d575b3e10b10abae1c186ba567384656f841
Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..659a2965fda55ec9c29507f2ad2e4b53e6c3faf2
Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/__init__.cpython-39.pyc differ
diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..89b3c34bf99e575d04b89fec5dc7c77a95a61478
Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-38.pyc differ
diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..16c75263cf6ffb8547540a799aa871fd8ee88f81
Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/iresnet.cpython-39.pyc differ
diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..20e575f543838da124b9d87c9bd6fae4bc1aa4a4
Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-38.pyc differ
diff --git a/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..726021fe012a771d45c5bd7fa101468c88bf74b3
Binary files /dev/null and b/src/face3d/models/arcface_torch/backbones/__pycache__/mobilefacenet.cpython-39.pyc differ
diff --git a/src/face3d/models/arcface_torch/backbones/iresnet.py b/src/face3d/models/arcface_torch/backbones/iresnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..d29f5f2bfbd444273717c4bc8aa20ba7edd08f80
--- /dev/null
+++ b/src/face3d/models/arcface_torch/backbones/iresnet.py
@@ -0,0 +1,187 @@
+import torch
+from torch import nn
+
+__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes,
+                     out_planes,
+                     kernel_size=3,
+                     stride=stride,
+                     padding=dilation,
+                     groups=groups,
+                     bias=False,
+                     dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes,
+                     out_planes,
+                     kernel_size=1,
+                     stride=stride,
+                     bias=False)
+
+
+class IBasicBlock(nn.Module):
+    expansion = 1
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 groups=1, base_width=64, dilation=1):
+        super(IBasicBlock, self).__init__()
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
+        self.conv1 = conv3x3(inplanes, planes)
+        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
+        self.prelu = nn.PReLU(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+        out = self.bn1(x)
+        out = self.conv1(out)
+        out = self.bn2(out)
+        out = self.prelu(out)
+        out = self.conv2(out)
+        out = self.bn3(out)
+        if self.downsample is not None:
+            identity = self.downsample(x)
+        out += identity
+        return out
+
+
+class IResNet(nn.Module):
+    fc_scale = 7 * 7
+    def __init__(self,
+                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+        super(IResNet, self).__init__()
+        self.fp16 = fp16
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+        self.prelu = nn.PReLU(self.inplanes)
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+        self.layer2 = self._make_layer(block,
+                                       128,
+                                       layers[1],
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block,
+                                       256,
+                                       layers[2],
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block,
+                                       512,
+                                       layers[3],
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
+        self.dropout = nn.Dropout(p=dropout, inplace=True)
+        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+        nn.init.constant_(self.features.weight, 1.0)
+        self.features.weight.requires_grad = False
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight, 0, 0.1)
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, IBasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+            )
+        layers = []
+        layers.append(
+            block(self.inplanes, planes, stride, downsample, self.groups,
+                  self.base_width, previous_dilation))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(self.inplanes,
+                      planes,
+                      groups=self.groups,
+                      base_width=self.base_width,
+                      dilation=self.dilation))
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        with torch.cuda.amp.autocast(self.fp16):
+            x = self.conv1(x)
+            x = self.bn1(x)
+            x = self.prelu(x)
+            x = self.layer1(x)
+            x = self.layer2(x)
+            x = self.layer3(x)
+            x = self.layer4(x)
+            x = self.bn2(x)
+            x = torch.flatten(x, 1)
+            x = self.dropout(x)
+        x = self.fc(x.float() if self.fp16 else x)
+        x = self.features(x)
+        return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = IResNet(block, layers, **kwargs)
+    if pretrained:
+        raise ValueError()
+    return model
+
+
+def iresnet18(pretrained=False, progress=True, **kwargs):
+    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
+                    progress, **kwargs)
+
+
+def iresnet34(pretrained=False, progress=True, **kwargs):
+    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
+                    progress, **kwargs)
+
+
+def iresnet50(pretrained=False, progress=True, **kwargs):
+    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
+                    progress, **kwargs)
+
+
+def iresnet100(pretrained=False, progress=True, **kwargs):
+    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
+                    progress, **kwargs)
+
+
+def iresnet200(pretrained=False, progress=True, **kwargs):
+    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
+                    progress, **kwargs)
+
diff --git a/src/face3d/models/arcface_torch/backbones/iresnet2060.py b/src/face3d/models/arcface_torch/backbones/iresnet2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..39bb4335716b653bd5924e20d616d825ef48339f
--- /dev/null
+++ b/src/face3d/models/arcface_torch/backbones/iresnet2060.py
@@ -0,0 +1,176 @@
+import torch
+from torch import nn
+
+assert torch.__version__ >= "1.8.1"
+from torch.utils.checkpoint import checkpoint_sequential
+
+__all__ = ['iresnet2060']
+
+
+def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes,
+                     out_planes,
+                     kernel_size=3,
+                     stride=stride,
+                     padding=dilation,
+                     groups=groups,
+                     bias=False,
+                     dilation=dilation)
+
+
+def conv1x1(in_planes, out_planes, stride=1):
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes,
+                     out_planes,
+                     kernel_size=1,
+                     stride=stride,
+                     bias=False)
+
+
+class IBasicBlock(nn.Module):
+    expansion = 1
+
+    def __init__(self, inplanes, planes, stride=1, downsample=None,
+                 groups=1, base_width=64, dilation=1):
+        super(IBasicBlock, self).__init__()
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
+        self.conv1 = conv3x3(inplanes, planes)
+        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
+        self.prelu = nn.PReLU(planes)
+        self.conv2 = conv3x3(planes, planes, stride)
+        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x):
+        identity = x
+        out = self.bn1(x)
+        out = self.conv1(out)
+        out = self.bn2(out)
+        out = self.prelu(out)
+        out = self.conv2(out)
+        out = self.bn3(out)
+        if self.downsample is not None:
+            identity = self.downsample(x)
+        out += identity
+        return out
+
+
+class IResNet(nn.Module):
+    fc_scale = 7 * 7
+
+    def __init__(self,
+                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
+                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
+        super(IResNet, self).__init__()
+        self.fp16 = fp16
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
+        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
+        self.prelu = nn.PReLU(self.inplanes)
+        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
+        self.layer2 = self._make_layer(block,
+                                       128,
+                                       layers[1],
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block,
+                                       256,
+                                       layers[2],
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block,
+                                       512,
+                                       layers[3],
+                                       stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
+        self.dropout = nn.Dropout(p=dropout, inplace=True)
+        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
+        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
+        nn.init.constant_(self.features.weight, 1.0)
+        self.features.weight.requires_grad = False
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.normal_(m.weight, 0, 0.1)
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, IBasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)
+
+    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
+            )
+        layers = []
+        layers.append(
+            block(self.inplanes, planes, stride, downsample, self.groups,
+                  self.base_width, previous_dilation))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(
+                block(self.inplanes,
+                      planes,
+                      groups=self.groups,
+                      base_width=self.base_width,
+                      dilation=self.dilation))
+
+        return nn.Sequential(*layers)
+
+    def checkpoint(self, func, num_seg, x):
+        if self.training:
+            return checkpoint_sequential(func, num_seg, x)
+        else:
+            return func(x)
+
+    def forward(self, x):
+        with torch.cuda.amp.autocast(self.fp16):
+            x = self.conv1(x)
+            x = self.bn1(x)
+            x = self.prelu(x)
+            x = self.layer1(x)
+            x = self.checkpoint(self.layer2, 20, x)
+            x = self.checkpoint(self.layer3, 100, x)
+            x = self.layer4(x)
+            x = self.bn2(x)
+            x = torch.flatten(x, 1)
+            x = self.dropout(x)
+        x = self.fc(x.float() if self.fp16 else x)
+        x = self.features(x)
+        return x
+
+
+def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
+    model = IResNet(block, layers, **kwargs)
+    if pretrained:
+        raise ValueError()
+    return model
+
+
+def iresnet2060(pretrained=False, progress=True, **kwargs):
+    return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
diff --git a/src/face3d/models/arcface_torch/backbones/mobilefacenet.py b/src/face3d/models/arcface_torch/backbones/mobilefacenet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c02c6c1e4fa6a6ddf09f5b01dec96971427cb110
--- /dev/null
+++ b/src/face3d/models/arcface_torch/backbones/mobilefacenet.py
@@ -0,0 +1,130 @@
+'''
+Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
+Original author cavalleria
+'''
+
+import torch.nn as nn
+from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
+import torch
+
+
+class Flatten(Module):
+    def forward(self, x):
+        return x.view(x.size(0), -1)
+
+
+class ConvBlock(Module):
+    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+        super(ConvBlock, self).__init__()
+        self.layers = nn.Sequential(
+            Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
+            BatchNorm2d(num_features=out_c),
+            PReLU(num_parameters=out_c)
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class LinearBlock(Module):
+    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
+        super(LinearBlock, self).__init__()
+        self.layers = nn.Sequential(
+            Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
+            BatchNorm2d(num_features=out_c)
+        )
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class DepthWise(Module):
+    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
+        super(DepthWise, self).__init__()
+        self.residual = residual
+        self.layers = nn.Sequential(
+            ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
+            ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
+            LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
+        )
+
+    def forward(self, x):
+        short_cut = None
+        if self.residual:
+            short_cut = x
+        x = self.layers(x)
+        if self.residual:
+            output = short_cut + x
+        else:
+            output = x
+        return output
+
+
+class Residual(Module):
+    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
+        super(Residual, self).__init__()
+        modules = []
+        for _ in range(num_block):
+            modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
+        self.layers = Sequential(*modules)
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class GDC(Module):
+    def __init__(self, embedding_size):
+        super(GDC, self).__init__()
+        self.layers = nn.Sequential(
+            LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
+            Flatten(),
+            Linear(512, embedding_size, bias=False),
+            BatchNorm1d(embedding_size))
+
+    def forward(self, x):
+        return self.layers(x)
+
+
+class MobileFaceNet(Module):
+    def __init__(self, fp16=False, num_features=512):
+        super(MobileFaceNet, self).__init__()
+        scale = 2
+        self.fp16 = fp16
+        self.layers = nn.Sequential(
+            ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
+            ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
+            DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
+            Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+            DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
+            Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+            DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
+            Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
+        )
+        self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
+        self.features = GDC(num_features)
+        self._initialize_weights()
+
+    def _initialize_weights(self):
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    m.bias.data.zero_()
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+            elif isinstance(m, nn.Linear):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+                if m.bias is not None:
+                    m.bias.data.zero_()
+
+    def forward(self, x):
+        with torch.cuda.amp.autocast(self.fp16):
+            x = self.layers(x)
+        x = self.conv_sep(x.float() if self.fp16 else x)
+        x = self.features(x)
+        return x
+
+
+def get_mbf(fp16, num_features):
+    return MobileFaceNet(fp16, num_features)
\ No newline at end of file
diff --git a/src/face3d/models/arcface_torch/configs/3millions.py b/src/face3d/models/arcface_torch/configs/3millions.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bee7cb4236e8b842a1bd1e8c26de7a11df0bf43
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/3millions.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 300 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/src/face3d/models/arcface_torch/configs/3millions_pfc.py b/src/face3d/models/arcface_torch/configs/3millions_pfc.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf7df5f04e2509e5dcc14adebbb9302a18f03f2b
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/3millions_pfc.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 300 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/src/face3d/models/arcface_torch/configs/__init__.py b/src/face3d/models/arcface_torch/configs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/arcface_torch/configs/base.py b/src/face3d/models/arcface_torch/configs/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..f98c62fed44afde276dcbacecd9da0a8f474963c
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/base.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = "ms1mv3_arcface_r50"
+
+config.dataset = "ms1m-retinaface-t1"
+config.embedding_size = 512
+config.sample_rate = 1
+config.fp16 = False
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+if config.dataset == "emore":
+    config.rec = "/train_tmp/faces_emore"
+    config.num_classes = 85742
+    config.num_image = 5822653
+    config.num_epoch = 16
+    config.warmup_epoch = -1
+    config.decay_epoch = [8, 14, ]
+    config.val_targets = ["lfw", ]
+
+elif config.dataset == "ms1m-retinaface-t1":
+    config.rec = "/train_tmp/ms1m-retinaface-t1"
+    config.num_classes = 93431
+    config.num_image = 5179510
+    config.num_epoch = 25
+    config.warmup_epoch = -1
+    config.decay_epoch = [11, 17, 22]
+    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
+
+elif config.dataset == "glint360k":
+    config.rec = "/train_tmp/glint360k"
+    config.num_classes = 360232
+    config.num_image = 17091657
+    config.num_epoch = 20
+    config.warmup_epoch = -1
+    config.decay_epoch = [8, 12, 15, 18]
+    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
+
+elif config.dataset == "webface":
+    config.rec = "/train_tmp/faces_webface_112x112"
+    config.num_classes = 10572
+    config.num_image = "forget"
+    config.num_epoch = 34
+    config.warmup_epoch = -1
+    config.decay_epoch = [20, 28, 32]
+    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_mbf.py b/src/face3d/models/arcface_torch/configs/glint360k_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ee5e8d96249d57196df43418f6fda4ab339877
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_mbf.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 0.1
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 2e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r100.py b/src/face3d/models/arcface_torch/configs/glint360k_r100.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8f8ef745c0efb9d5ea67409edc8c904def8a9d9
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_r100.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r100"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r18.py b/src/face3d/models/arcface_torch/configs/glint360k_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..473b59a954fffcaddca132fb6e0f32cbe70c70f4
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_r18.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r34.py b/src/face3d/models/arcface_torch/configs/glint360k_r34.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9c22ff0c82cc98bbbe81c9a1c26c9b3fc186105
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_r34.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r34"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/glint360k_r50.py b/src/face3d/models/arcface_torch/configs/glint360k_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ecbfda06730e3842e7b347db366e82f0714912f
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/glint360k_r50.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "cosface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/glint360k"
+config.num_classes = 360232
+config.num_image = 17091657
+config.num_epoch = 20
+config.warmup_epoch = -1
+config.decay_epoch = [8, 12, 15, 18]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py b/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
new file mode 100644
index 0000000000000000000000000000000000000000..47c87a99867db55c7f689574c331c14cda23ea96
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "mbf"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 2e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 20, 25]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
new file mode 100644
index 0000000000000000000000000000000000000000..1aeb851b05ea22e01da87b3d387812f0253989f8
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r18.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r18"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
new file mode 100644
index 0000000000000000000000000000000000000000..8693e67080dac7e7b84da08a62df326c7b12d465
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r2060.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r2060"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 64
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
new file mode 100644
index 0000000000000000000000000000000000000000..52bff483db179045c0e3acc8e2975477182b0756
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r34.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r34"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py b/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
new file mode 100644
index 0000000000000000000000000000000000000000..de81ffdd84edd6fcea7fcb4d3594db031b9e4e26
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/ms1mv3_r50.py
@@ -0,0 +1,26 @@
+from easydict import EasyDict as edict
+
+# make training faster
+# our RAM is 256G
+# mount -t tmpfs -o size=140G  tmpfs /train_tmp
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "/train_tmp/ms1m-retinaface-t1"
+config.num_classes = 93431
+config.num_image = 5179510
+config.num_epoch = 25
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
diff --git a/src/face3d/models/arcface_torch/configs/speed.py b/src/face3d/models/arcface_torch/configs/speed.py
new file mode 100644
index 0000000000000000000000000000000000000000..c172f9d44d39b534f2253630471e91cf78e6fba7
--- /dev/null
+++ b/src/face3d/models/arcface_torch/configs/speed.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict as edict
+
+# configs for test speed
+
+config = edict()
+config.loss = "arcface"
+config.network = "r50"
+config.resume = False
+config.output = None
+config.embedding_size = 512
+config.sample_rate = 1.0
+config.fp16 = True
+config.momentum = 0.9
+config.weight_decay = 5e-4
+config.batch_size = 128
+config.lr = 0.1  # batch size is 512
+
+config.rec = "synthetic"
+config.num_classes = 100 * 10000
+config.num_epoch = 30
+config.warmup_epoch = -1
+config.decay_epoch = [10, 16, 22]
+config.val_targets = []
diff --git a/src/face3d/models/arcface_torch/dataset.py b/src/face3d/models/arcface_torch/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bead250243237c650fa3138f6aa172d4f98535f
--- /dev/null
+++ b/src/face3d/models/arcface_torch/dataset.py
@@ -0,0 +1,124 @@
+import numbers
+import os
+import queue as Queue
+import threading
+
+import mxnet as mx
+import numpy as np
+import torch
+from torch.utils.data import DataLoader, Dataset
+from torchvision import transforms
+
+
+class BackgroundGenerator(threading.Thread):
+    def __init__(self, generator, local_rank, max_prefetch=6):
+        super(BackgroundGenerator, self).__init__()
+        self.queue = Queue.Queue(max_prefetch)
+        self.generator = generator
+        self.local_rank = local_rank
+        self.daemon = True
+        self.start()
+
+    def run(self):
+        torch.cuda.set_device(self.local_rank)
+        for item in self.generator:
+            self.queue.put(item)
+        self.queue.put(None)
+
+    def next(self):
+        next_item = self.queue.get()
+        if next_item is None:
+            raise StopIteration
+        return next_item
+
+    def __next__(self):
+        return self.next()
+
+    def __iter__(self):
+        return self
+
+
+class DataLoaderX(DataLoader):
+
+    def __init__(self, local_rank, **kwargs):
+        super(DataLoaderX, self).__init__(**kwargs)
+        self.stream = torch.cuda.Stream(local_rank)
+        self.local_rank = local_rank
+
+    def __iter__(self):
+        self.iter = super(DataLoaderX, self).__iter__()
+        self.iter = BackgroundGenerator(self.iter, self.local_rank)
+        self.preload()
+        return self
+
+    def preload(self):
+        self.batch = next(self.iter, None)
+        if self.batch is None:
+            return None
+        with torch.cuda.stream(self.stream):
+            for k in range(len(self.batch)):
+                self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)
+
+    def __next__(self):
+        torch.cuda.current_stream().wait_stream(self.stream)
+        batch = self.batch
+        if batch is None:
+            raise StopIteration
+        self.preload()
+        return batch
+
+
+class MXFaceDataset(Dataset):
+    def __init__(self, root_dir, local_rank):
+        super(MXFaceDataset, self).__init__()
+        self.transform = transforms.Compose(
+            [transforms.ToPILImage(),
+             transforms.RandomHorizontalFlip(),
+             transforms.ToTensor(),
+             transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
+             ])
+        self.root_dir = root_dir
+        self.local_rank = local_rank
+        path_imgrec = os.path.join(root_dir, 'train.rec')
+        path_imgidx = os.path.join(root_dir, 'train.idx')
+        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, 'r')
+        s = self.imgrec.read_idx(0)
+        header, _ = mx.recordio.unpack(s)
+        if header.flag > 0:
+            self.header0 = (int(header.label[0]), int(header.label[1]))
+            self.imgidx = np.array(range(1, int(header.label[0])))
+        else:
+            self.imgidx = np.array(list(self.imgrec.keys))
+
+    def __getitem__(self, index):
+        idx = self.imgidx[index]
+        s = self.imgrec.read_idx(idx)
+        header, img = mx.recordio.unpack(s)
+        label = header.label
+        if not isinstance(label, numbers.Number):
+            label = label[0]
+        label = torch.tensor(label, dtype=torch.long)
+        sample = mx.image.imdecode(img).asnumpy()
+        if self.transform is not None:
+            sample = self.transform(sample)
+        return sample, label
+
+    def __len__(self):
+        return len(self.imgidx)
+
+
+class SyntheticDataset(Dataset):
+    def __init__(self, local_rank):
+        super(SyntheticDataset, self).__init__()
+        img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+        img = np.transpose(img, (2, 0, 1))
+        img = torch.from_numpy(img).squeeze(0).float()
+        img = ((img / 255) - 0.5) / 0.5
+        self.img = img
+        self.label = 1
+
+    def __getitem__(self, index):
+        return self.img, self.label
+
+    def __len__(self):
+        return 1000000
diff --git a/src/face3d/models/arcface_torch/docs/eval.md b/src/face3d/models/arcface_torch/docs/eval.md
new file mode 100644
index 0000000000000000000000000000000000000000..4d29c855fc6e4245ed264216c1f96ab2efc57248
--- /dev/null
+++ b/src/face3d/models/arcface_torch/docs/eval.md
@@ -0,0 +1,31 @@
+## Eval on ICCV2021-MFR
+
+coming soon.
+
+
+## Eval IJBC
+You can eval ijbc with pytorch or onnx.
+
+
+1. Eval IJBC With Onnx
+```shell
+CUDA_VISIBLE_DEVICES=0 python onnx_ijbc.py --model-root ms1mv3_arcface_r50 --image-path IJB_release/IJBC --result-dir ms1mv3_arcface_r50
+```
+
+2. Eval IJBC With Pytorch
+```shell
+CUDA_VISIBLE_DEVICES=0,1 python eval_ijbc.py \
+--model-prefix ms1mv3_arcface_r50/backbone.pth \
+--image-path IJB_release/IJBC \
+--result-dir ms1mv3_arcface_r50 \
+--batch-size 128 \
+--job ms1mv3_arcface_r50 \
+--target IJBC \
+--network iresnet50
+```
+
+## Inference
+
+```shell
+python inference.py --weight ms1mv3_arcface_r50/backbone.pth --network r50
+```
diff --git a/src/face3d/models/arcface_torch/docs/install.md b/src/face3d/models/arcface_torch/docs/install.md
new file mode 100644
index 0000000000000000000000000000000000000000..b1b770a0d93dac1f160185b5bbf4da2f414f21f6
--- /dev/null
+++ b/src/face3d/models/arcface_torch/docs/install.md
@@ -0,0 +1,51 @@
+## v1.8.0 
+### Linux and Windows  
+```shell
+# CUDA 11.0
+pip --default-timeout=100 install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip --default-timeout=100 install torch==1.8.0 torchvision==0.9.0 torchaudio==0.8.0
+
+# CPU only
+pip --default-timeout=100 install torch==1.8.0+cpu torchvision==0.9.0+cpu torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
+
+```
+
+
+## v1.7.1  
+### Linux and Windows  
+```shell
+# CUDA 11.0
+pip install torch==1.7.1+cu110 torchvision==0.8.2+cu110 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 10.2
+pip install torch==1.7.1 torchvision==0.8.2 torchaudio==0.7.2
+
+# CUDA 10.1
+pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.7.1+cu92 torchvision==0.8.2+cu92 torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.7.1+cpu torchvision==0.8.2+cpu torchaudio==0.7.2 -f https://download.pytorch.org/whl/torch_stable.html
+```
+
+
+## v1.6.0  
+
+### Linux and Windows
+```shell
+# CUDA 10.2
+pip install torch==1.6.0 torchvision==0.7.0
+
+# CUDA 10.1
+pip install torch==1.6.0+cu101 torchvision==0.7.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CUDA 9.2
+pip install torch==1.6.0+cu92 torchvision==0.7.0+cu92 -f https://download.pytorch.org/whl/torch_stable.html
+
+# CPU only
+pip install torch==1.6.0+cpu torchvision==0.7.0+cpu -f https://download.pytorch.org/whl/torch_stable.html
+```
\ No newline at end of file
diff --git a/src/face3d/models/arcface_torch/docs/modelzoo.md b/src/face3d/models/arcface_torch/docs/modelzoo.md
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/arcface_torch/docs/speed_benchmark.md b/src/face3d/models/arcface_torch/docs/speed_benchmark.md
new file mode 100644
index 0000000000000000000000000000000000000000..d54904587df4e13784dc68d5709b4d7d97490890
--- /dev/null
+++ b/src/face3d/models/arcface_torch/docs/speed_benchmark.md
@@ -0,0 +1,93 @@
+## Test Training Speed
+
+- Test Commands
+
+You need to use the following two commands to test the Partial FC training performance. 
+The number of identites is **3 millions** (synthetic data), turn mixed precision  training on, backbone is resnet50, 
+batch size is 1024.
+```shell
+# Model Parallel
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions
+# Partial FC 0.1
+python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/3millions_pfc
+```
+
+- GPU Memory
+
+```
+# (Model Parallel) gpustat -i
+[0] Tesla V100-SXM2-32GB | 64'C,  94 % | 30338 / 32510 MB 
+[1] Tesla V100-SXM2-32GB | 60'C,  99 % | 28876 / 32510 MB 
+[2] Tesla V100-SXM2-32GB | 60'C,  99 % | 28872 / 32510 MB 
+[3] Tesla V100-SXM2-32GB | 69'C,  99 % | 28872 / 32510 MB 
+[4] Tesla V100-SXM2-32GB | 66'C,  99 % | 28888 / 32510 MB 
+[5] Tesla V100-SXM2-32GB | 60'C,  99 % | 28932 / 32510 MB 
+[6] Tesla V100-SXM2-32GB | 68'C, 100 % | 28916 / 32510 MB 
+[7] Tesla V100-SXM2-32GB | 65'C,  99 % | 28860 / 32510 MB 
+
+# (Partial FC 0.1) gpustat -i
+[0] Tesla V100-SXM2-32GB | 60'C,  95 % | 10488 / 32510 MB                                                                                                                                          │·······················
+[1] Tesla V100-SXM2-32GB | 60'C,  97 % | 10344 / 32510 MB                                                                                                                                          │·······················
+[2] Tesla V100-SXM2-32GB | 61'C,  95 % | 10340 / 32510 MB                                                                                                                                          │·······················
+[3] Tesla V100-SXM2-32GB | 66'C,  95 % | 10340 / 32510 MB                                                                                                                                          │·······················
+[4] Tesla V100-SXM2-32GB | 65'C,  94 % | 10356 / 32510 MB                                                                                                                                          │·······················
+[5] Tesla V100-SXM2-32GB | 61'C,  95 % | 10400 / 32510 MB                                                                                                                                          │·······················
+[6] Tesla V100-SXM2-32GB | 68'C,  96 % | 10384 / 32510 MB                                                                                                                                          │·······················
+[7] Tesla V100-SXM2-32GB | 64'C,  95 % | 10328 / 32510 MB                                                                                                                                        │·······················
+```
+
+- Training Speed
+
+```python
+# (Model Parallel) trainging.log
+Training: Speed 2271.33 samples/sec   Loss 1.1624   LearningRate 0.2000   Epoch: 0   Global Step: 100 
+Training: Speed 2269.94 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 150 
+Training: Speed 2272.67 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 200 
+Training: Speed 2266.55 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 250 
+Training: Speed 2272.54 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 300 
+
+# (Partial FC 0.1) trainging.log
+Training: Speed 5299.56 samples/sec   Loss 1.0965   LearningRate 0.2000   Epoch: 0   Global Step: 100  
+Training: Speed 5296.37 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 150  
+Training: Speed 5304.37 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 200  
+Training: Speed 5274.43 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 250  
+Training: Speed 5300.10 samples/sec   Loss 0.0000   LearningRate 0.2000   Epoch: 0   Global Step: 300   
+```
+
+In this test case, Partial FC 0.1 only use1 1/3 of the GPU memory of the model parallel, 
+and the training speed is 2.5 times faster than the model parallel.
+
+
+## Speed Benchmark
+
+1. Training speed of different parallel methods (samples/second), Tesla V100 32GB * 8. (Larger is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :---    | :--- | :--- | :--- |
+|125000   | 4681 | 4824 | 5004 |
+|250000   | 4047 | 4521 | 4976 |
+|500000   | 3087 | 4013 | 4900 |
+|1000000  | 2090 | 3449 | 4803 |
+|1400000  | 1672 | 3043 | 4738 |
+|2000000  | -    | 2593 | 4626 |
+|4000000  | -    | 1748 | 4208 |
+|5500000  | -    | 1389 | 3975 |
+|8000000  | -    | -    | 3565 |
+|16000000 | -    | -    | 2679 |
+|29000000 | -    | -    | 1855 |
+
+2. GPU memory cost of different parallel methods (GB per GPU), Tesla V100 32GB * 8. (Smaller is better)
+
+| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
+| :---    | :---  | :---  | :---  |
+|125000   | 7358  | 5306  | 4868  |
+|250000   | 9940  | 5826  | 5004  |
+|500000   | 14220 | 7114  | 5202  |
+|1000000  | 23708 | 9966  | 5620  |
+|1400000  | 32252 | 11178 | 6056  |
+|2000000  | -     | 13978 | 6472  |
+|4000000  | -     | 23238 | 8284  |
+|5500000  | -     | 32188 | 9854  |
+|8000000  | -     | -     | 12310 |
+|16000000 | -     | -     | 19950 |
+|29000000 | -     | -     | 32324 |
diff --git a/src/face3d/models/arcface_torch/eval/__init__.py b/src/face3d/models/arcface_torch/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/arcface_torch/eval/verification.py b/src/face3d/models/arcface_torch/eval/verification.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b1f5618184effae64895847af1a65d43d2e4418
--- /dev/null
+++ b/src/face3d/models/arcface_torch/eval/verification.py
@@ -0,0 +1,407 @@
+"""Helper for evaluation on the Labeled Faces in the Wild dataset 
+"""
+
+# MIT License
+#
+# Copyright (c) 2016 David Sandberg
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import datetime
+import os
+import pickle
+
+import mxnet as mx
+import numpy as np
+import sklearn
+import torch
+from mxnet import ndarray as nd
+from scipy import interpolate
+from sklearn.decomposition import PCA
+from sklearn.model_selection import KFold
+
+
+class LFold:
+    def __init__(self, n_splits=2, shuffle=False):
+        self.n_splits = n_splits
+        if self.n_splits > 1:
+            self.k_fold = KFold(n_splits=n_splits, shuffle=shuffle)
+
+    def split(self, indices):
+        if self.n_splits > 1:
+            return self.k_fold.split(indices)
+        else:
+            return [(indices, indices)]
+
+
+def calculate_roc(thresholds,
+                  embeddings1,
+                  embeddings2,
+                  actual_issame,
+                  nrof_folds=10,
+                  pca=0):
+    assert (embeddings1.shape[0] == embeddings2.shape[0])
+    assert (embeddings1.shape[1] == embeddings2.shape[1])
+    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+    nrof_thresholds = len(thresholds)
+    k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+    tprs = np.zeros((nrof_folds, nrof_thresholds))
+    fprs = np.zeros((nrof_folds, nrof_thresholds))
+    accuracy = np.zeros((nrof_folds))
+    indices = np.arange(nrof_pairs)
+
+    if pca == 0:
+        diff = np.subtract(embeddings1, embeddings2)
+        dist = np.sum(np.square(diff), 1)
+
+    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+        if pca > 0:
+            print('doing pca on', fold_idx)
+            embed1_train = embeddings1[train_set]
+            embed2_train = embeddings2[train_set]
+            _embed_train = np.concatenate((embed1_train, embed2_train), axis=0)
+            pca_model = PCA(n_components=pca)
+            pca_model.fit(_embed_train)
+            embed1 = pca_model.transform(embeddings1)
+            embed2 = pca_model.transform(embeddings2)
+            embed1 = sklearn.preprocessing.normalize(embed1)
+            embed2 = sklearn.preprocessing.normalize(embed2)
+            diff = np.subtract(embed1, embed2)
+            dist = np.sum(np.square(diff), 1)
+
+        # Find the best threshold for the fold
+        acc_train = np.zeros((nrof_thresholds))
+        for threshold_idx, threshold in enumerate(thresholds):
+            _, _, acc_train[threshold_idx] = calculate_accuracy(
+                threshold, dist[train_set], actual_issame[train_set])
+        best_threshold_index = np.argmax(acc_train)
+        for threshold_idx, threshold in enumerate(thresholds):
+            tprs[fold_idx, threshold_idx], fprs[fold_idx, threshold_idx], _ = calculate_accuracy(
+                threshold, dist[test_set],
+                actual_issame[test_set])
+        _, _, accuracy[fold_idx] = calculate_accuracy(
+            thresholds[best_threshold_index], dist[test_set],
+            actual_issame[test_set])
+
+    tpr = np.mean(tprs, 0)
+    fpr = np.mean(fprs, 0)
+    return tpr, fpr, accuracy
+
+
+def calculate_accuracy(threshold, dist, actual_issame):
+    predict_issame = np.less(dist, threshold)
+    tp = np.sum(np.logical_and(predict_issame, actual_issame))
+    fp = np.sum(np.logical_and(predict_issame, np.logical_not(actual_issame)))
+    tn = np.sum(
+        np.logical_and(np.logical_not(predict_issame),
+                       np.logical_not(actual_issame)))
+    fn = np.sum(np.logical_and(np.logical_not(predict_issame), actual_issame))
+
+    tpr = 0 if (tp + fn == 0) else float(tp) / float(tp + fn)
+    fpr = 0 if (fp + tn == 0) else float(fp) / float(fp + tn)
+    acc = float(tp + tn) / dist.size
+    return tpr, fpr, acc
+
+
+def calculate_val(thresholds,
+                  embeddings1,
+                  embeddings2,
+                  actual_issame,
+                  far_target,
+                  nrof_folds=10):
+    assert (embeddings1.shape[0] == embeddings2.shape[0])
+    assert (embeddings1.shape[1] == embeddings2.shape[1])
+    nrof_pairs = min(len(actual_issame), embeddings1.shape[0])
+    nrof_thresholds = len(thresholds)
+    k_fold = LFold(n_splits=nrof_folds, shuffle=False)
+
+    val = np.zeros(nrof_folds)
+    far = np.zeros(nrof_folds)
+
+    diff = np.subtract(embeddings1, embeddings2)
+    dist = np.sum(np.square(diff), 1)
+    indices = np.arange(nrof_pairs)
+
+    for fold_idx, (train_set, test_set) in enumerate(k_fold.split(indices)):
+
+        # Find the threshold that gives FAR = far_target
+        far_train = np.zeros(nrof_thresholds)
+        for threshold_idx, threshold in enumerate(thresholds):
+            _, far_train[threshold_idx] = calculate_val_far(
+                threshold, dist[train_set], actual_issame[train_set])
+        if np.max(far_train) >= far_target:
+            f = interpolate.interp1d(far_train, thresholds, kind='slinear')
+            threshold = f(far_target)
+        else:
+            threshold = 0.0
+
+        val[fold_idx], far[fold_idx] = calculate_val_far(
+            threshold, dist[test_set], actual_issame[test_set])
+
+    val_mean = np.mean(val)
+    far_mean = np.mean(far)
+    val_std = np.std(val)
+    return val_mean, val_std, far_mean
+
+
+def calculate_val_far(threshold, dist, actual_issame):
+    predict_issame = np.less(dist, threshold)
+    true_accept = np.sum(np.logical_and(predict_issame, actual_issame))
+    false_accept = np.sum(
+        np.logical_and(predict_issame, np.logical_not(actual_issame)))
+    n_same = np.sum(actual_issame)
+    n_diff = np.sum(np.logical_not(actual_issame))
+    # print(true_accept, false_accept)
+    # print(n_same, n_diff)
+    val = float(true_accept) / float(n_same)
+    far = float(false_accept) / float(n_diff)
+    return val, far
+
+
+def evaluate(embeddings, actual_issame, nrof_folds=10, pca=0):
+    # Calculate evaluation metrics
+    thresholds = np.arange(0, 4, 0.01)
+    embeddings1 = embeddings[0::2]
+    embeddings2 = embeddings[1::2]
+    tpr, fpr, accuracy = calculate_roc(thresholds,
+                                       embeddings1,
+                                       embeddings2,
+                                       np.asarray(actual_issame),
+                                       nrof_folds=nrof_folds,
+                                       pca=pca)
+    thresholds = np.arange(0, 4, 0.001)
+    val, val_std, far = calculate_val(thresholds,
+                                      embeddings1,
+                                      embeddings2,
+                                      np.asarray(actual_issame),
+                                      1e-3,
+                                      nrof_folds=nrof_folds)
+    return tpr, fpr, accuracy, val, val_std, far
+
+@torch.no_grad()
+def load_bin(path, image_size):
+    try:
+        with open(path, 'rb') as f:
+            bins, issame_list = pickle.load(f)  # py2
+    except UnicodeDecodeError as e:
+        with open(path, 'rb') as f:
+            bins, issame_list = pickle.load(f, encoding='bytes')  # py3
+    data_list = []
+    for flip in [0, 1]:
+        data = torch.empty((len(issame_list) * 2, 3, image_size[0], image_size[1]))
+        data_list.append(data)
+    for idx in range(len(issame_list) * 2):
+        _bin = bins[idx]
+        img = mx.image.imdecode(_bin)
+        if img.shape[1] != image_size[0]:
+            img = mx.image.resize_short(img, image_size[0])
+        img = nd.transpose(img, axes=(2, 0, 1))
+        for flip in [0, 1]:
+            if flip == 1:
+                img = mx.ndarray.flip(data=img, axis=2)
+            data_list[flip][idx][:] = torch.from_numpy(img.asnumpy())
+        if idx % 1000 == 0:
+            print('loading bin', idx)
+    print(data_list[0].shape)
+    return data_list, issame_list
+
+@torch.no_grad()
+def test(data_set, backbone, batch_size, nfolds=10):
+    print('testing verification..')
+    data_list = data_set[0]
+    issame_list = data_set[1]
+    embeddings_list = []
+    time_consumed = 0.0
+    for i in range(len(data_list)):
+        data = data_list[i]
+        embeddings = None
+        ba = 0
+        while ba < data.shape[0]:
+            bb = min(ba + batch_size, data.shape[0])
+            count = bb - ba
+            _data = data[bb - batch_size: bb]
+            time0 = datetime.datetime.now()
+            img = ((_data / 255) - 0.5) / 0.5
+            net_out: torch.Tensor = backbone(img)
+            _embeddings = net_out.detach().cpu().numpy()
+            time_now = datetime.datetime.now()
+            diff = time_now - time0
+            time_consumed += diff.total_seconds()
+            if embeddings is None:
+                embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+            embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+            ba = bb
+        embeddings_list.append(embeddings)
+
+    _xnorm = 0.0
+    _xnorm_cnt = 0
+    for embed in embeddings_list:
+        for i in range(embed.shape[0]):
+            _em = embed[i]
+            _norm = np.linalg.norm(_em)
+            _xnorm += _norm
+            _xnorm_cnt += 1
+    _xnorm /= _xnorm_cnt
+
+    acc1 = 0.0
+    std1 = 0.0
+    embeddings = embeddings_list[0] + embeddings_list[1]
+    embeddings = sklearn.preprocessing.normalize(embeddings)
+    print(embeddings.shape)
+    print('infer time', time_consumed)
+    _, _, accuracy, val, val_std, far = evaluate(embeddings, issame_list, nrof_folds=nfolds)
+    acc2, std2 = np.mean(accuracy), np.std(accuracy)
+    return acc1, std1, acc2, std2, _xnorm, embeddings_list
+
+
+def dumpR(data_set,
+          backbone,
+          batch_size,
+          name='',
+          data_extra=None,
+          label_shape=None):
+    print('dump verification embedding..')
+    data_list = data_set[0]
+    issame_list = data_set[1]
+    embeddings_list = []
+    time_consumed = 0.0
+    for i in range(len(data_list)):
+        data = data_list[i]
+        embeddings = None
+        ba = 0
+        while ba < data.shape[0]:
+            bb = min(ba + batch_size, data.shape[0])
+            count = bb - ba
+
+            _data = nd.slice_axis(data, axis=0, begin=bb - batch_size, end=bb)
+            time0 = datetime.datetime.now()
+            if data_extra is None:
+                db = mx.io.DataBatch(data=(_data,), label=(_label,))
+            else:
+                db = mx.io.DataBatch(data=(_data, _data_extra),
+                                     label=(_label,))
+            model.forward(db, is_train=False)
+            net_out = model.get_outputs()
+            _embeddings = net_out[0].asnumpy()
+            time_now = datetime.datetime.now()
+            diff = time_now - time0
+            time_consumed += diff.total_seconds()
+            if embeddings is None:
+                embeddings = np.zeros((data.shape[0], _embeddings.shape[1]))
+            embeddings[ba:bb, :] = _embeddings[(batch_size - count):, :]
+            ba = bb
+        embeddings_list.append(embeddings)
+    embeddings = embeddings_list[0] + embeddings_list[1]
+    embeddings = sklearn.preprocessing.normalize(embeddings)
+    actual_issame = np.asarray(issame_list)
+    outname = os.path.join('temp.bin')
+    with open(outname, 'wb') as f:
+        pickle.dump((embeddings, issame_list),
+                    f,
+                    protocol=pickle.HIGHEST_PROTOCOL)
+
+
+# if __name__ == '__main__':
+#
+#     parser = argparse.ArgumentParser(description='do verification')
+#     # general
+#     parser.add_argument('--data-dir', default='', help='')
+#     parser.add_argument('--model',
+#                         default='../model/softmax,50',
+#                         help='path to load model.')
+#     parser.add_argument('--target',
+#                         default='lfw,cfp_ff,cfp_fp,agedb_30',
+#                         help='test targets.')
+#     parser.add_argument('--gpu', default=0, type=int, help='gpu id')
+#     parser.add_argument('--batch-size', default=32, type=int, help='')
+#     parser.add_argument('--max', default='', type=str, help='')
+#     parser.add_argument('--mode', default=0, type=int, help='')
+#     parser.add_argument('--nfolds', default=10, type=int, help='')
+#     args = parser.parse_args()
+#     image_size = [112, 112]
+#     print('image_size', image_size)
+#     ctx = mx.gpu(args.gpu)
+#     nets = []
+#     vec = args.model.split(',')
+#     prefix = args.model.split(',')[0]
+#     epochs = []
+#     if len(vec) == 1:
+#         pdir = os.path.dirname(prefix)
+#         for fname in os.listdir(pdir):
+#             if not fname.endswith('.params'):
+#                 continue
+#             _file = os.path.join(pdir, fname)
+#             if _file.startswith(prefix):
+#                 epoch = int(fname.split('.')[0].split('-')[1])
+#                 epochs.append(epoch)
+#         epochs = sorted(epochs, reverse=True)
+#         if len(args.max) > 0:
+#             _max = [int(x) for x in args.max.split(',')]
+#             assert len(_max) == 2
+#             if len(epochs) > _max[1]:
+#                 epochs = epochs[_max[0]:_max[1]]
+#
+#     else:
+#         epochs = [int(x) for x in vec[1].split('|')]
+#     print('model number', len(epochs))
+#     time0 = datetime.datetime.now()
+#     for epoch in epochs:
+#         print('loading', prefix, epoch)
+#         sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
+#         # arg_params, aux_params = ch_dev(arg_params, aux_params, ctx)
+#         all_layers = sym.get_internals()
+#         sym = all_layers['fc1_output']
+#         model = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+#         # model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0], image_size[1]))], label_shapes=[('softmax_label', (args.batch_size,))])
+#         model.bind(data_shapes=[('data', (args.batch_size, 3, image_size[0],
+#                                           image_size[1]))])
+#         model.set_params(arg_params, aux_params)
+#         nets.append(model)
+#     time_now = datetime.datetime.now()
+#     diff = time_now - time0
+#     print('model loading time', diff.total_seconds())
+#
+#     ver_list = []
+#     ver_name_list = []
+#     for name in args.target.split(','):
+#         path = os.path.join(args.data_dir, name + ".bin")
+#         if os.path.exists(path):
+#             print('loading.. ', name)
+#             data_set = load_bin(path, image_size)
+#             ver_list.append(data_set)
+#             ver_name_list.append(name)
+#
+#     if args.mode == 0:
+#         for i in range(len(ver_list)):
+#             results = []
+#             for model in nets:
+#                 acc1, std1, acc2, std2, xnorm, embeddings_list = test(
+#                     ver_list[i], model, args.batch_size, args.nfolds)
+#                 print('[%s]XNorm: %f' % (ver_name_list[i], xnorm))
+#                 print('[%s]Accuracy: %1.5f+-%1.5f' % (ver_name_list[i], acc1, std1))
+#                 print('[%s]Accuracy-Flip: %1.5f+-%1.5f' % (ver_name_list[i], acc2, std2))
+#                 results.append(acc2)
+#             print('Max of [%s] is %1.5f' % (ver_name_list[i], np.max(results)))
+#     elif args.mode == 1:
+#         raise ValueError
+#     else:
+#         model = nets[0]
+#         dumpR(ver_list[0], model, args.batch_size, args.target)
diff --git a/src/face3d/models/arcface_torch/eval_ijbc.py b/src/face3d/models/arcface_torch/eval_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..64844c4723a88b4b160d2fee9a7b626b987981d9
--- /dev/null
+++ b/src/face3d/models/arcface_torch/eval_ijbc.py
@@ -0,0 +1,483 @@
+# coding: utf-8
+
+import os
+import pickle
+
+import matplotlib
+import pandas as pd
+
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+import timeit
+import sklearn
+import argparse
+import cv2
+import numpy as np
+import torch
+from skimage import transform as trans
+from backbones import get_model
+from sklearn.metrics import roc_curve, auc
+
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from pathlib import Path
+
+import sys
+import warnings
+
+sys.path.insert(0, "../")
+warnings.filterwarnings("ignore")
+
+parser = argparse.ArgumentParser(description='do ijb test')
+# general
+parser.add_argument('--model-prefix', default='', help='path to load model.')
+parser.add_argument('--image-path', default='', type=str, help='')
+parser.add_argument('--result-dir', default='.', type=str, help='')
+parser.add_argument('--batch-size', default=128, type=int, help='')
+parser.add_argument('--network', default='iresnet50', type=str, help='')
+parser.add_argument('--job', default='insightface', type=str, help='job name')
+parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+args = parser.parse_args()
+
+target = args.target
+model_path = args.model_prefix
+image_path = args.image_path
+result_dir = args.result_dir
+gpu_id = None
+use_norm_score = True  # if Ture, TestMode(N1)
+use_detector_score = True  # if Ture, TestMode(D1)
+use_flip_test = True  # if Ture, TestMode(F1)
+job = args.job
+batch_size = args.batch_size
+
+
+class Embedding(object):
+    def __init__(self, prefix, data_shape, batch_size=1):
+        image_size = (112, 112)
+        self.image_size = image_size
+        weight = torch.load(prefix)
+        resnet = get_model(args.network, dropout=0, fp16=False).cuda()
+        resnet.load_state_dict(weight)
+        model = torch.nn.DataParallel(resnet)
+        self.model = model
+        self.model.eval()
+        src = np.array([
+            [30.2946, 51.6963],
+            [65.5318, 51.5014],
+            [48.0252, 71.7366],
+            [33.5493, 92.3655],
+            [62.7299, 92.2041]], dtype=np.float32)
+        src[:, 0] += 8.0
+        self.src = src
+        self.batch_size = batch_size
+        self.data_shape = data_shape
+
+    def get(self, rimg, landmark):
+
+        assert landmark.shape[0] == 68 or landmark.shape[0] == 5
+        assert landmark.shape[1] == 2
+        if landmark.shape[0] == 68:
+            landmark5 = np.zeros((5, 2), dtype=np.float32)
+            landmark5[0] = (landmark[36] + landmark[39]) / 2
+            landmark5[1] = (landmark[42] + landmark[45]) / 2
+            landmark5[2] = landmark[30]
+            landmark5[3] = landmark[48]
+            landmark5[4] = landmark[54]
+        else:
+            landmark5 = landmark
+        tform = trans.SimilarityTransform()
+        tform.estimate(landmark5, self.src)
+        M = tform.params[0:2, :]
+        img = cv2.warpAffine(rimg,
+                             M, (self.image_size[1], self.image_size[0]),
+                             borderValue=0.0)
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        img_flip = np.fliplr(img)
+        img = np.transpose(img, (2, 0, 1))  # 3*112*112, RGB
+        img_flip = np.transpose(img_flip, (2, 0, 1))
+        input_blob = np.zeros((2, 3, self.image_size[1], self.image_size[0]), dtype=np.uint8)
+        input_blob[0] = img
+        input_blob[1] = img_flip
+        return input_blob
+
+    @torch.no_grad()
+    def forward_db(self, batch_data):
+        imgs = torch.Tensor(batch_data).cuda()
+        imgs.div_(255).sub_(0.5).div_(0.5)
+        feat = self.model(imgs)
+        feat = feat.reshape([self.batch_size, 2 * feat.shape[1]])
+        return feat.cpu().numpy()
+
+
+# 将一个list尽量均分成n份,限制len(list)==n,份数大于原list内元素个数则分配空list[]
+def divideIntoNstrand(listTemp, n):
+    twoList = [[] for i in range(n)]
+    for i, e in enumerate(listTemp):
+        twoList[i % n].append(e)
+    return twoList
+
+
+def read_template_media_list(path):
+    # ijb_meta = np.loadtxt(path, dtype=str)
+    ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+    templates = ijb_meta[:, 1].astype(np.int)
+    medias = ijb_meta[:, 2].astype(np.int)
+    return templates, medias
+
+
+# In[ ]:
+
+
+def read_template_pair_list(path):
+    # pairs = np.loadtxt(path, dtype=str)
+    pairs = pd.read_csv(path, sep=' ', header=None).values
+    # print(pairs.shape)
+    # print(pairs[:, 0].astype(np.int))
+    t1 = pairs[:, 0].astype(np.int)
+    t2 = pairs[:, 1].astype(np.int)
+    label = pairs[:, 2].astype(np.int)
+    return t1, t2, label
+
+
+# In[ ]:
+
+
+def read_image_feature(path):
+    with open(path, 'rb') as fid:
+        img_feats = pickle.load(fid)
+    return img_feats
+
+
+# In[ ]:
+
+
+def get_image_feature(img_path, files_list, model_path, epoch, gpu_id):
+    batch_size = args.batch_size
+    data_shape = (3, 112, 112)
+
+    files = files_list
+    print('files:', len(files))
+    rare_size = len(files) % batch_size
+    faceness_scores = []
+    batch = 0
+    img_feats = np.empty((len(files), 1024), dtype=np.float32)
+
+    batch_data = np.empty((2 * batch_size, 3, 112, 112))
+    embedding = Embedding(model_path, data_shape, batch_size)
+    for img_index, each_line in enumerate(files[:len(files) - rare_size]):
+        name_lmk_score = each_line.strip().split(' ')
+        img_name = os.path.join(img_path, name_lmk_score[0])
+        img = cv2.imread(img_name)
+        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+                       dtype=np.float32)
+        lmk = lmk.reshape((5, 2))
+        input_blob = embedding.get(img, lmk)
+
+        batch_data[2 * (img_index - batch * batch_size)][:] = input_blob[0]
+        batch_data[2 * (img_index - batch * batch_size) + 1][:] = input_blob[1]
+        if (img_index + 1) % batch_size == 0:
+            print('batch', batch)
+            img_feats[batch * batch_size:batch * batch_size +
+                                         batch_size][:] = embedding.forward_db(batch_data)
+            batch += 1
+        faceness_scores.append(name_lmk_score[-1])
+
+    batch_data = np.empty((2 * rare_size, 3, 112, 112))
+    embedding = Embedding(model_path, data_shape, rare_size)
+    for img_index, each_line in enumerate(files[len(files) - rare_size:]):
+        name_lmk_score = each_line.strip().split(' ')
+        img_name = os.path.join(img_path, name_lmk_score[0])
+        img = cv2.imread(img_name)
+        lmk = np.array([float(x) for x in name_lmk_score[1:-1]],
+                       dtype=np.float32)
+        lmk = lmk.reshape((5, 2))
+        input_blob = embedding.get(img, lmk)
+        batch_data[2 * img_index][:] = input_blob[0]
+        batch_data[2 * img_index + 1][:] = input_blob[1]
+        if (img_index + 1) % rare_size == 0:
+            print('batch', batch)
+            img_feats[len(files) -
+                      rare_size:][:] = embedding.forward_db(batch_data)
+            batch += 1
+        faceness_scores.append(name_lmk_score[-1])
+    faceness_scores = np.array(faceness_scores).astype(np.float32)
+    # img_feats = np.ones( (len(files), 1024), dtype=np.float32) * 0.01
+    # faceness_scores = np.ones( (len(files), ), dtype=np.float32 )
+    return img_feats, faceness_scores
+
+
+# In[ ]:
+
+
+def image2template_feature(img_feats=None, templates=None, medias=None):
+    # ==========================================================
+    # 1. face image feature l2 normalization. img_feats:[number_image x feats_dim]
+    # 2. compute media feature.
+    # 3. compute template feature.
+    # ==========================================================
+    unique_templates = np.unique(templates)
+    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+
+    for count_template, uqt in enumerate(unique_templates):
+
+        (ind_t,) = np.where(templates == uqt)
+        face_norm_feats = img_feats[ind_t]
+        face_medias = medias[ind_t]
+        unique_medias, unique_media_counts = np.unique(face_medias,
+                                                       return_counts=True)
+        media_norm_feats = []
+        for u, ct in zip(unique_medias, unique_media_counts):
+            (ind_m,) = np.where(face_medias == u)
+            if ct == 1:
+                media_norm_feats += [face_norm_feats[ind_m]]
+            else:  # image features from the same video will be aggregated into one feature
+                media_norm_feats += [
+                    np.mean(face_norm_feats[ind_m], axis=0, keepdims=True)
+                ]
+        media_norm_feats = np.array(media_norm_feats)
+        # media_norm_feats = media_norm_feats / np.sqrt(np.sum(media_norm_feats ** 2, -1, keepdims=True))
+        template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+        if count_template % 2000 == 0:
+            print('Finish Calculating {} template features.'.format(
+                count_template))
+    # template_norm_feats = template_feats / np.sqrt(np.sum(template_feats ** 2, -1, keepdims=True))
+    template_norm_feats = sklearn.preprocessing.normalize(template_feats)
+    # print(template_norm_feats.shape)
+    return template_norm_feats, unique_templates
+
+
+# In[ ]:
+
+
+def verification(template_norm_feats=None,
+                 unique_templates=None,
+                 p1=None,
+                 p2=None):
+    # ==========================================================
+    #         Compute set-to-set Similarity Score.
+    # ==========================================================
+    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+    for count_template, uqt in enumerate(unique_templates):
+        template2id[uqt] = count_template
+
+    score = np.zeros((len(p1),))  # save cosine distance between pairs
+
+    total_pairs = np.array(range(len(p1)))
+    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limiation
+    sublists = [
+        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+    ]
+    total_sublists = len(sublists)
+    for c, s in enumerate(sublists):
+        feat1 = template_norm_feats[template2id[p1[s]]]
+        feat2 = template_norm_feats[template2id[p2[s]]]
+        similarity_score = np.sum(feat1 * feat2, -1)
+        score[s] = similarity_score.flatten()
+        if c % 10 == 0:
+            print('Finish {}/{} pairs.'.format(c, total_sublists))
+    return score
+
+
+# In[ ]:
+def verification2(template_norm_feats=None,
+                  unique_templates=None,
+                  p1=None,
+                  p2=None):
+    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+    for count_template, uqt in enumerate(unique_templates):
+        template2id[uqt] = count_template
+    score = np.zeros((len(p1),))  # save cosine distance between pairs
+    total_pairs = np.array(range(len(p1)))
+    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limiation
+    sublists = [
+        total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)
+    ]
+    total_sublists = len(sublists)
+    for c, s in enumerate(sublists):
+        feat1 = template_norm_feats[template2id[p1[s]]]
+        feat2 = template_norm_feats[template2id[p2[s]]]
+        similarity_score = np.sum(feat1 * feat2, -1)
+        score[s] = similarity_score.flatten()
+        if c % 10 == 0:
+            print('Finish {}/{} pairs.'.format(c, total_sublists))
+    return score
+
+
+def read_score(path):
+    with open(path, 'rb') as fid:
+        img_feats = pickle.load(fid)
+    return img_feats
+
+
+# # Step1: Load Meta Data
+
+# In[ ]:
+
+assert target == 'IJBC' or target == 'IJBB'
+
+# =============================================================
+# load image and template relationships for template feature embedding
+# tid --> template id,  mid --> media id
+# format:
+#           image_name tid mid
+# =============================================================
+start = timeit.default_timer()
+templates, medias = read_template_media_list(
+    os.path.join('%s/meta' % image_path,
+                 '%s_face_tid_mid.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+
+# =============================================================
+# load template pairs for template-to-template verification
+# tid : template id,  label : 1/0
+# format:
+#           tid_1 tid_2 label
+# =============================================================
+start = timeit.default_timer()
+p1, p2, label = read_template_pair_list(
+    os.path.join('%s/meta' % image_path,
+                 '%s_template_pair_label.txt' % target.lower()))
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 2: Get Image Features
+
+# In[ ]:
+
+# =============================================================
+# load image features
+# format:
+#           img_feats: [image_num x feats_dim] (227630, 512)
+# =============================================================
+start = timeit.default_timer()
+img_path = '%s/loose_crop' % image_path
+img_list_path = '%s/meta/%s_name_5pts_score.txt' % (image_path, target.lower())
+img_list = open(img_list_path)
+files = img_list.readlines()
+# files_list = divideIntoNstrand(files, rank_size)
+files_list = files
+
+# img_feats
+# for i in range(rank_size):
+img_feats, faceness_scores = get_image_feature(img_path, files_list,
+                                               model_path, 0, gpu_id)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0],
+                                          img_feats.shape[1]))
+
+# # Step3: Get Template Features
+
+# In[ ]:
+
+# =============================================================
+# compute template features from image features.
+# =============================================================
+start = timeit.default_timer()
+# ==========================================================
+# Norm feature before aggregation into template feature?
+# Feature norm from embedding network and faceness score are able to decrease weights for noise samples (not face).
+# ==========================================================
+# 1. FaceScore (Feature Norm)
+# 2. FaceScore (Detector)
+
+if use_flip_test:
+    # concat --- F1
+    # img_input_feats = img_feats
+    # add --- F2
+    img_input_feats = img_feats[:, 0:img_feats.shape[1] //
+                                     2] + img_feats[:, img_feats.shape[1] // 2:]
+else:
+    img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
+
+if use_norm_score:
+    img_input_feats = img_input_feats
+else:
+    # normalise features to remove norm information
+    img_input_feats = img_input_feats / np.sqrt(
+        np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+if use_detector_score:
+    print(img_input_feats.shape, faceness_scores.shape)
+    img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+else:
+    img_input_feats = img_input_feats
+
+template_norm_feats, unique_templates = image2template_feature(
+    img_input_feats, templates, medias)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# # Step 4: Get Template Similarity Scores
+
+# In[ ]:
+
+# =============================================================
+# compute verification scores between template pairs.
+# =============================================================
+start = timeit.default_timer()
+score = verification(template_norm_feats, unique_templates, p1, p2)
+stop = timeit.default_timer()
+print('Time: %.2f s. ' % (stop - start))
+
+# In[ ]:
+save_path = os.path.join(result_dir, args.job)
+# save_path = result_dir + '/%s_result' % target
+
+if not os.path.exists(save_path):
+    os.makedirs(save_path)
+
+score_save_file = os.path.join(save_path, "%s.npy" % target.lower())
+np.save(score_save_file, score)
+
+# # Step 5: Get ROC Curves and TPR@FPR Table
+
+# In[ ]:
+
+files = [score_save_file]
+methods = []
+scores = []
+for file in files:
+    methods.append(Path(file).stem)
+    scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+    fpr, tpr, _ = roc_curve(label, scores[method])
+    roc_auc = auc(fpr, tpr)
+    fpr = np.flipud(fpr)
+    tpr = np.flipud(tpr)  # select largest tpr at same fpr
+    plt.plot(fpr,
+             tpr,
+             color=colours[method],
+             lw=1,
+             label=('[%s (AUC = %0.4f %%)]' %
+                    (method.split('-')[-1], roc_auc * 100)))
+    tpr_fpr_row = []
+    tpr_fpr_row.append("%s-%s" % (method, target))
+    for fpr_iter in np.arange(len(x_labels)):
+        _, min_index = min(
+            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+    tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+fig.savefig(os.path.join(save_path, '%s.pdf' % target.lower()))
+print(tpr_fpr_table)
diff --git a/src/face3d/models/arcface_torch/inference.py b/src/face3d/models/arcface_torch/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..1929d4abb640d040398dda57b491b9bd96deac9d
--- /dev/null
+++ b/src/face3d/models/arcface_torch/inference.py
@@ -0,0 +1,35 @@
+import argparse
+
+import cv2
+import numpy as np
+import torch
+
+from backbones import get_model
+
+
+@torch.no_grad()
+def inference(weight, name, img):
+    if img is None:
+        img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.uint8)
+    else:
+        img = cv2.imread(img)
+        img = cv2.resize(img, (112, 112))
+
+    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+    img = np.transpose(img, (2, 0, 1))
+    img = torch.from_numpy(img).unsqueeze(0).float()
+    img.div_(255).sub_(0.5).div_(0.5)
+    net = get_model(name, fp16=False)
+    net.load_state_dict(torch.load(weight))
+    net.eval()
+    feat = net(img).numpy()
+    print(feat)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+    parser.add_argument('--network', type=str, default='r50', help='backbone network')
+    parser.add_argument('--weight', type=str, default='')
+    parser.add_argument('--img', type=str, default=None)
+    args = parser.parse_args()
+    inference(args.weight, args.network, args.img)
diff --git a/src/face3d/models/arcface_torch/losses.py b/src/face3d/models/arcface_torch/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bfdd8c6b7f6b0d465928f19c554e62340e5ad7b
--- /dev/null
+++ b/src/face3d/models/arcface_torch/losses.py
@@ -0,0 +1,42 @@
+import torch
+from torch import nn
+
+
+def get_loss(name):
+    if name == "cosface":
+        return CosFace()
+    elif name == "arcface":
+        return ArcFace()
+    else:
+        raise ValueError()
+
+
+class CosFace(nn.Module):
+    def __init__(self, s=64.0, m=0.40):
+        super(CosFace, self).__init__()
+        self.s = s
+        self.m = m
+
+    def forward(self, cosine, label):
+        index = torch.where(label != -1)[0]
+        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+        m_hot.scatter_(1, label[index, None], self.m)
+        cosine[index] -= m_hot
+        ret = cosine * self.s
+        return ret
+
+
+class ArcFace(nn.Module):
+    def __init__(self, s=64.0, m=0.5):
+        super(ArcFace, self).__init__()
+        self.s = s
+        self.m = m
+
+    def forward(self, cosine: torch.Tensor, label):
+        index = torch.where(label != -1)[0]
+        m_hot = torch.zeros(index.size()[0], cosine.size()[1], device=cosine.device)
+        m_hot.scatter_(1, label[index, None], self.m)
+        cosine.acos_()
+        cosine[index] += m_hot
+        cosine.cos_().mul_(self.s)
+        return cosine
diff --git a/src/face3d/models/arcface_torch/onnx_helper.py b/src/face3d/models/arcface_torch/onnx_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a01a46621dc0ea695bd903de5d1e212d424c860
--- /dev/null
+++ b/src/face3d/models/arcface_torch/onnx_helper.py
@@ -0,0 +1,250 @@
+from __future__ import division
+import datetime
+import os
+import os.path as osp
+import glob
+import numpy as np
+import cv2
+import sys
+import onnxruntime
+import onnx
+import argparse
+from onnx import numpy_helper
+from insightface.data import get_image
+
+class ArcFaceORT:
+    def __init__(self, model_path, cpu=False):
+        self.model_path = model_path
+        # providers = None will use available provider, for onnxruntime-gpu it will be "CUDAExecutionProvider"
+        self.providers = ['CPUExecutionProvider'] if cpu else None
+
+    #input_size is (w,h), return error message, return None if success
+    def check(self, track='cfat', test_img = None):
+        #default is cfat
+        max_model_size_mb=1024
+        max_feat_dim=512
+        max_time_cost=15
+        if track.startswith('ms1m'):
+            max_model_size_mb=1024
+            max_feat_dim=512
+            max_time_cost=10
+        elif track.startswith('glint'):
+            max_model_size_mb=1024
+            max_feat_dim=1024
+            max_time_cost=20
+        elif track.startswith('cfat'):
+            max_model_size_mb = 1024
+            max_feat_dim = 512
+            max_time_cost = 15
+        elif track.startswith('unconstrained'):
+            max_model_size_mb=1024
+            max_feat_dim=1024
+            max_time_cost=30
+        else:
+            return "track not found"
+
+        if not os.path.exists(self.model_path):
+            return "model_path not exists"
+        if not os.path.isdir(self.model_path):
+            return "model_path should be directory"
+        onnx_files = []
+        for _file in os.listdir(self.model_path):
+            if _file.endswith('.onnx'):
+                onnx_files.append(osp.join(self.model_path, _file))
+        if len(onnx_files)==0:
+            return "do not have onnx files"
+        self.model_file = sorted(onnx_files)[-1]
+        print('use onnx-model:', self.model_file)
+        try:
+            session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+        except:
+            return "load onnx failed"
+        input_cfg = session.get_inputs()[0]
+        input_shape = input_cfg.shape
+        print('input-shape:', input_shape)
+        if len(input_shape)!=4:
+            return "length of input_shape should be 4"
+        if not isinstance(input_shape[0], str):
+            #return "input_shape[0] should be str to support batch-inference"
+            print('reset input-shape[0] to None')
+            model = onnx.load(self.model_file)
+            model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+            new_model_file = osp.join(self.model_path, 'zzzzrefined.onnx')
+            onnx.save(model, new_model_file)
+            self.model_file = new_model_file
+            print('use new onnx-model:', self.model_file)
+            try:
+                session = onnxruntime.InferenceSession(self.model_file, providers=self.providers)
+            except:
+                return "load onnx failed"
+            input_cfg = session.get_inputs()[0]
+            input_shape = input_cfg.shape
+            print('new-input-shape:', input_shape)
+
+        self.image_size = tuple(input_shape[2:4][::-1])
+        #print('image_size:', self.image_size)
+        input_name = input_cfg.name
+        outputs = session.get_outputs()
+        output_names = []
+        for o in outputs:
+            output_names.append(o.name)
+            #print(o.name, o.shape)
+        if len(output_names)!=1:
+            return "number of output nodes should be 1"
+        self.session = session
+        self.input_name = input_name
+        self.output_names = output_names
+        #print(self.output_names)
+        model = onnx.load(self.model_file)
+        graph = model.graph
+        if len(graph.node)<8:
+            return "too small onnx graph"
+
+        input_size = (112,112)
+        self.crop = None
+        if track=='cfat':
+            crop_file = osp.join(self.model_path, 'crop.txt')
+            if osp.exists(crop_file):
+                lines = open(crop_file,'r').readlines()
+                if len(lines)!=6:
+                    return "crop.txt should contain 6 lines"
+                lines = [int(x) for x in lines]
+                self.crop = lines[:4]
+                input_size = tuple(lines[4:6])
+        if input_size!=self.image_size:
+            return "input-size is inconsistant with onnx model input, %s vs %s"%(input_size, self.image_size)
+
+        self.model_size_mb = os.path.getsize(self.model_file) / float(1024*1024)
+        if self.model_size_mb > max_model_size_mb:
+            return "max model size exceed, given %.3f-MB"%self.model_size_mb
+
+        input_mean = None
+        input_std = None
+        if track=='cfat':
+            pn_file = osp.join(self.model_path, 'pixel_norm.txt')
+            if osp.exists(pn_file):
+                lines = open(pn_file,'r').readlines()
+                if len(lines)!=2:
+                    return "pixel_norm.txt should contain 2 lines"
+                input_mean = float(lines[0])
+                input_std = float(lines[1])
+        if input_mean is not None or input_std is not None:
+            if input_mean is None or input_std is None:
+                return "please set input_mean and input_std simultaneously"
+        else:
+            find_sub = False
+            find_mul = False
+            for nid, node in enumerate(graph.node[:8]):
+                print(nid, node.name)
+                if node.name.startswith('Sub') or node.name.startswith('_minus'):
+                    find_sub = True
+                if node.name.startswith('Mul') or node.name.startswith('_mul') or node.name.startswith('Div'):
+                    find_mul = True
+            if find_sub and find_mul:
+                print("find sub and mul")
+                #mxnet arcface model
+                input_mean = 0.0
+                input_std = 1.0
+            else:
+                input_mean = 127.5
+                input_std = 127.5
+        self.input_mean = input_mean
+        self.input_std = input_std
+        for initn in graph.initializer:
+            weight_array = numpy_helper.to_array(initn)
+            dt = weight_array.dtype
+            if dt.itemsize<4:
+                return 'invalid weight type - (%s:%s)' % (initn.name, dt.name)
+        if test_img is None:
+            test_img = get_image('Tom_Hanks_54745')
+            test_img = cv2.resize(test_img, self.image_size)
+        else:
+            test_img = cv2.resize(test_img, self.image_size)
+        feat, cost = self.benchmark(test_img)
+        batch_result = self.check_batch(test_img)
+        batch_result_sum = float(np.sum(batch_result))
+        if batch_result_sum in [float('inf'), -float('inf')] or batch_result_sum != batch_result_sum:
+            print(batch_result)
+            print(batch_result_sum)
+            return "batch result output contains NaN!"
+
+        if len(feat.shape) < 2:
+           return "the shape of the feature must be two, but get {}".format(str(feat.shape))
+
+        if feat.shape[1] > max_feat_dim:
+            return "max feat dim exceed, given %d"%feat.shape[1]
+        self.feat_dim = feat.shape[1]
+        cost_ms = cost*1000
+        if cost_ms>max_time_cost:
+            return "max time cost exceed, given %.4f"%cost_ms
+        self.cost_ms = cost_ms
+        print('check stat:, model-size-mb: %.4f, feat-dim: %d, time-cost-ms: %.4f, input-mean: %.3f, input-std: %.3f'%(self.model_size_mb, self.feat_dim, self.cost_ms, self.input_mean, self.input_std))
+        return None
+
+    def check_batch(self, img):
+        if not isinstance(img, list):
+            imgs = [img, ] * 32
+        if self.crop is not None:
+            nimgs = []
+            for img in imgs:
+                nimg = img[self.crop[1]:self.crop[3], self.crop[0]:self.crop[2], :]
+                if nimg.shape[0] != self.image_size[1] or nimg.shape[1] != self.image_size[0]:
+                    nimg = cv2.resize(nimg, self.image_size)
+                nimgs.append(nimg)
+            imgs = nimgs
+        blob = cv2.dnn.blobFromImages(
+            images=imgs, scalefactor=1.0 / self.input_std, size=self.image_size,
+            mean=(self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+        net_out = self.session.run(self.output_names, {self.input_name: blob})[0]
+        return net_out
+
+
+    def meta_info(self):
+        return {'model-size-mb':self.model_size_mb, 'feature-dim':self.feat_dim, 'infer': self.cost_ms}
+
+
+    def forward(self, imgs):
+        if not isinstance(imgs, list):
+            imgs = [imgs]
+        input_size = self.image_size
+        if self.crop is not None:
+            nimgs = []
+            for img in imgs:
+                nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+                if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+                    nimg = cv2.resize(nimg, input_size)
+                nimgs.append(nimg)
+            imgs = nimgs
+        blob = cv2.dnn.blobFromImages(imgs, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+        net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+        return net_out
+
+    def benchmark(self, img):
+        input_size = self.image_size
+        if self.crop is not None:
+            nimg = img[self.crop[1]:self.crop[3],self.crop[0]:self.crop[2],:]
+            if nimg.shape[0]!=input_size[1] or nimg.shape[1]!=input_size[0]:
+                nimg = cv2.resize(nimg, input_size)
+            img = nimg
+        blob = cv2.dnn.blobFromImage(img, 1.0/self.input_std, input_size, (self.input_mean, self.input_mean, self.input_mean), swapRB=True)
+        costs = []
+        for _ in range(50):
+            ta = datetime.datetime.now()
+            net_out = self.session.run(self.output_names, {self.input_name : blob})[0]
+            tb = datetime.datetime.now()
+            cost = (tb-ta).total_seconds()
+            costs.append(cost)
+        costs = sorted(costs)
+        cost = costs[5]
+        return net_out, cost
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='')
+    # general
+    parser.add_argument('workdir', help='submitted work dir', type=str)
+    parser.add_argument('--track', help='track name, for different challenge', type=str, default='cfat')
+    args = parser.parse_args()
+    handler = ArcFaceORT(args.workdir)
+    err = handler.check(args.track)
+    print('err:', err)
diff --git a/src/face3d/models/arcface_torch/onnx_ijbc.py b/src/face3d/models/arcface_torch/onnx_ijbc.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa96b96745e23d4d6642d99f71456c10af5e4e4e
--- /dev/null
+++ b/src/face3d/models/arcface_torch/onnx_ijbc.py
@@ -0,0 +1,267 @@
+import argparse
+import os
+import pickle
+import timeit
+
+import cv2
+import mxnet as mx
+import numpy as np
+import pandas as pd
+import prettytable
+import skimage.transform
+from sklearn.metrics import roc_curve
+from sklearn.preprocessing import normalize
+
+from onnx_helper import ArcFaceORT
+
+SRC = np.array(
+    [
+        [30.2946, 51.6963],
+        [65.5318, 51.5014],
+        [48.0252, 71.7366],
+        [33.5493, 92.3655],
+        [62.7299, 92.2041]]
+    , dtype=np.float32)
+SRC[:, 0] += 8.0
+
+
+class AlignedDataSet(mx.gluon.data.Dataset):
+    def __init__(self, root, lines, align=True):
+        self.lines = lines
+        self.root = root
+        self.align = align
+
+    def __len__(self):
+        return len(self.lines)
+
+    def __getitem__(self, idx):
+        each_line = self.lines[idx]
+        name_lmk_score = each_line.strip().split(' ')
+        name = os.path.join(self.root, name_lmk_score[0])
+        img = cv2.cvtColor(cv2.imread(name), cv2.COLOR_BGR2RGB)
+        landmark5 = np.array([float(x) for x in name_lmk_score[1:-1]], dtype=np.float32).reshape((5, 2))
+        st = skimage.transform.SimilarityTransform()
+        st.estimate(landmark5, SRC)
+        img = cv2.warpAffine(img, st.params[0:2, :], (112, 112), borderValue=0.0)
+        img_1 = np.expand_dims(img, 0)
+        img_2 = np.expand_dims(np.fliplr(img), 0)
+        output = np.concatenate((img_1, img_2), axis=0).astype(np.float32)
+        output = np.transpose(output, (0, 3, 1, 2))
+        output = mx.nd.array(output)
+        return output
+
+
+def extract(model_root, dataset):
+    model = ArcFaceORT(model_path=model_root)
+    model.check()
+    feat_mat = np.zeros(shape=(len(dataset), 2 * model.feat_dim))
+
+    def batchify_fn(data):
+        return mx.nd.concat(*data, dim=0)
+
+    data_loader = mx.gluon.data.DataLoader(
+        dataset, 128, last_batch='keep', num_workers=4,
+        thread_pool=True, prefetch=16, batchify_fn=batchify_fn)
+    num_iter = 0
+    for batch in data_loader:
+        batch = batch.asnumpy()
+        batch = (batch - model.input_mean) / model.input_std
+        feat = model.session.run(model.output_names, {model.input_name: batch})[0]
+        feat = np.reshape(feat, (-1, model.feat_dim * 2))
+        feat_mat[128 * num_iter: 128 * num_iter + feat.shape[0], :] = feat
+        num_iter += 1
+        if num_iter % 50 == 0:
+            print(num_iter)
+    return feat_mat
+
+
+def read_template_media_list(path):
+    ijb_meta = pd.read_csv(path, sep=' ', header=None).values
+    templates = ijb_meta[:, 1].astype(np.int)
+    medias = ijb_meta[:, 2].astype(np.int)
+    return templates, medias
+
+
+def read_template_pair_list(path):
+    pairs = pd.read_csv(path, sep=' ', header=None).values
+    t1 = pairs[:, 0].astype(np.int)
+    t2 = pairs[:, 1].astype(np.int)
+    label = pairs[:, 2].astype(np.int)
+    return t1, t2, label
+
+
+def read_image_feature(path):
+    with open(path, 'rb') as fid:
+        img_feats = pickle.load(fid)
+    return img_feats
+
+
+def image2template_feature(img_feats=None,
+                           templates=None,
+                           medias=None):
+    unique_templates = np.unique(templates)
+    template_feats = np.zeros((len(unique_templates), img_feats.shape[1]))
+    for count_template, uqt in enumerate(unique_templates):
+        (ind_t,) = np.where(templates == uqt)
+        face_norm_feats = img_feats[ind_t]
+        face_medias = medias[ind_t]
+        unique_medias, unique_media_counts = np.unique(face_medias, return_counts=True)
+        media_norm_feats = []
+        for u, ct in zip(unique_medias, unique_media_counts):
+            (ind_m,) = np.where(face_medias == u)
+            if ct == 1:
+                media_norm_feats += [face_norm_feats[ind_m]]
+            else:  # image features from the same video will be aggregated into one feature
+                media_norm_feats += [np.mean(face_norm_feats[ind_m], axis=0, keepdims=True), ]
+        media_norm_feats = np.array(media_norm_feats)
+        template_feats[count_template] = np.sum(media_norm_feats, axis=0)
+        if count_template % 2000 == 0:
+            print('Finish Calculating {} template features.'.format(
+                count_template))
+    template_norm_feats = normalize(template_feats)
+    return template_norm_feats, unique_templates
+
+
+def verification(template_norm_feats=None,
+                 unique_templates=None,
+                 p1=None,
+                 p2=None):
+    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+    for count_template, uqt in enumerate(unique_templates):
+        template2id[uqt] = count_template
+    score = np.zeros((len(p1),))
+    total_pairs = np.array(range(len(p1)))
+    batchsize = 100000
+    sublists = [total_pairs[i: i + batchsize] for i in range(0, len(p1), batchsize)]
+    total_sublists = len(sublists)
+    for c, s in enumerate(sublists):
+        feat1 = template_norm_feats[template2id[p1[s]]]
+        feat2 = template_norm_feats[template2id[p2[s]]]
+        similarity_score = np.sum(feat1 * feat2, -1)
+        score[s] = similarity_score.flatten()
+        if c % 10 == 0:
+            print('Finish {}/{} pairs.'.format(c, total_sublists))
+    return score
+
+
+def verification2(template_norm_feats=None,
+                  unique_templates=None,
+                  p1=None,
+                  p2=None):
+    template2id = np.zeros((max(unique_templates) + 1, 1), dtype=int)
+    for count_template, uqt in enumerate(unique_templates):
+        template2id[uqt] = count_template
+    score = np.zeros((len(p1),))  # save cosine distance between pairs
+    total_pairs = np.array(range(len(p1)))
+    batchsize = 100000  # small batchsize instead of all pairs in one batch due to the memory limiation
+    sublists = [total_pairs[i:i + batchsize] for i in range(0, len(p1), batchsize)]
+    total_sublists = len(sublists)
+    for c, s in enumerate(sublists):
+        feat1 = template_norm_feats[template2id[p1[s]]]
+        feat2 = template_norm_feats[template2id[p2[s]]]
+        similarity_score = np.sum(feat1 * feat2, -1)
+        score[s] = similarity_score.flatten()
+        if c % 10 == 0:
+            print('Finish {}/{} pairs.'.format(c, total_sublists))
+    return score
+
+
+def main(args):
+    use_norm_score = True  # if Ture, TestMode(N1)
+    use_detector_score = True  # if Ture, TestMode(D1)
+    use_flip_test = True  # if Ture, TestMode(F1)
+    assert args.target == 'IJBC' or args.target == 'IJBB'
+
+    start = timeit.default_timer()
+    templates, medias = read_template_media_list(
+        os.path.join('%s/meta' % args.image_path, '%s_face_tid_mid.txt' % args.target.lower()))
+    stop = timeit.default_timer()
+    print('Time: %.2f s. ' % (stop - start))
+
+    start = timeit.default_timer()
+    p1, p2, label = read_template_pair_list(
+        os.path.join('%s/meta' % args.image_path,
+                     '%s_template_pair_label.txt' % args.target.lower()))
+    stop = timeit.default_timer()
+    print('Time: %.2f s. ' % (stop - start))
+
+    start = timeit.default_timer()
+    img_path = '%s/loose_crop' % args.image_path
+    img_list_path = '%s/meta/%s_name_5pts_score.txt' % (args.image_path, args.target.lower())
+    img_list = open(img_list_path)
+    files = img_list.readlines()
+    dataset = AlignedDataSet(root=img_path, lines=files, align=True)
+    img_feats = extract(args.model_root, dataset)
+
+    faceness_scores = []
+    for each_line in files:
+        name_lmk_score = each_line.split()
+        faceness_scores.append(name_lmk_score[-1])
+    faceness_scores = np.array(faceness_scores).astype(np.float32)
+    stop = timeit.default_timer()
+    print('Time: %.2f s. ' % (stop - start))
+    print('Feature Shape: ({} , {}) .'.format(img_feats.shape[0], img_feats.shape[1]))
+    start = timeit.default_timer()
+
+    if use_flip_test:
+        img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2] + img_feats[:, img_feats.shape[1] // 2:]
+    else:
+        img_input_feats = img_feats[:, 0:img_feats.shape[1] // 2]
+
+    if use_norm_score:
+        img_input_feats = img_input_feats
+    else:
+        img_input_feats = img_input_feats / np.sqrt(np.sum(img_input_feats ** 2, -1, keepdims=True))
+
+    if use_detector_score:
+        print(img_input_feats.shape, faceness_scores.shape)
+        img_input_feats = img_input_feats * faceness_scores[:, np.newaxis]
+    else:
+        img_input_feats = img_input_feats
+
+    template_norm_feats, unique_templates = image2template_feature(
+        img_input_feats, templates, medias)
+    stop = timeit.default_timer()
+    print('Time: %.2f s. ' % (stop - start))
+
+    start = timeit.default_timer()
+    score = verification(template_norm_feats, unique_templates, p1, p2)
+    stop = timeit.default_timer()
+    print('Time: %.2f s. ' % (stop - start))
+    save_path = os.path.join(args.result_dir, "{}_result".format(args.target))
+    if not os.path.exists(save_path):
+        os.makedirs(save_path)
+    score_save_file = os.path.join(save_path, "{}.npy".format(args.model_root))
+    np.save(score_save_file, score)
+    files = [score_save_file]
+    methods = []
+    scores = []
+    for file in files:
+        methods.append(os.path.basename(file))
+        scores.append(np.load(file))
+    methods = np.array(methods)
+    scores = dict(zip(methods, scores))
+    x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+    tpr_fpr_table = prettytable.PrettyTable(['Methods'] + [str(x) for x in x_labels])
+    for method in methods:
+        fpr, tpr, _ = roc_curve(label, scores[method])
+        fpr = np.flipud(fpr)
+        tpr = np.flipud(tpr)
+        tpr_fpr_row = []
+        tpr_fpr_row.append("%s-%s" % (method, args.target))
+        for fpr_iter in np.arange(len(x_labels)):
+            _, min_index = min(
+                list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+            tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+        tpr_fpr_table.add_row(tpr_fpr_row)
+    print(tpr_fpr_table)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='do ijb test')
+    # general
+    parser.add_argument('--model-root', default='', help='path to load model.')
+    parser.add_argument('--image-path', default='', type=str, help='')
+    parser.add_argument('--result-dir', default='.', type=str, help='')
+    parser.add_argument('--target', default='IJBC', type=str, help='target, set to IJBC or IJBB')
+    main(parser.parse_args())
diff --git a/src/face3d/models/arcface_torch/partial_fc.py b/src/face3d/models/arcface_torch/partial_fc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0286dd437319c920ecb61f4eb3a32333dcf49eb
--- /dev/null
+++ b/src/face3d/models/arcface_torch/partial_fc.py
@@ -0,0 +1,222 @@
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+from torch.nn import Module
+from torch.nn.functional import normalize, linear
+from torch.nn.parameter import Parameter
+
+
+class PartialFC(Module):
+    """
+    Author: {Xiang An, Yang Xiao, XuHan Zhu} in DeepGlint,
+    Partial FC: Training 10 Million Identities on a Single Machine
+    See the original paper:
+    https://arxiv.org/abs/2010.05222
+    """
+
+    @torch.no_grad()
+    def __init__(self, rank, local_rank, world_size, batch_size, resume,
+                 margin_softmax, num_classes, sample_rate=1.0, embedding_size=512, prefix="./"):
+        """
+        rank: int
+            Unique process(GPU) ID from 0 to world_size - 1.
+        local_rank: int
+            Unique process(GPU) ID within the server from 0 to 7.
+        world_size: int
+            Number of GPU.
+        batch_size: int
+            Batch size on current rank(GPU).
+        resume: bool
+            Select whether to restore the weight of softmax.
+        margin_softmax: callable
+            A function of margin softmax, eg: cosface, arcface.
+        num_classes: int
+            The number of class center storage in current rank(CPU/GPU), usually is total_classes // world_size,
+            required.
+        sample_rate: float
+            The partial fc sampling rate, when the number of classes increases to more than 2 millions, Sampling
+            can greatly speed up training, and reduce a lot of GPU memory, default is 1.0.
+        embedding_size: int
+            The feature dimension, default is 512.
+        prefix: str
+            Path for save checkpoint, default is './'.
+        """
+        super(PartialFC, self).__init__()
+        #
+        self.num_classes: int = num_classes
+        self.rank: int = rank
+        self.local_rank: int = local_rank
+        self.device: torch.device = torch.device("cuda:{}".format(self.local_rank))
+        self.world_size: int = world_size
+        self.batch_size: int = batch_size
+        self.margin_softmax: callable = margin_softmax
+        self.sample_rate: float = sample_rate
+        self.embedding_size: int = embedding_size
+        self.prefix: str = prefix
+        self.num_local: int = num_classes // world_size + int(rank < num_classes % world_size)
+        self.class_start: int = num_classes // world_size * rank + min(rank, num_classes % world_size)
+        self.num_sample: int = int(self.sample_rate * self.num_local)
+
+        self.weight_name = os.path.join(self.prefix, "rank_{}_softmax_weight.pt".format(self.rank))
+        self.weight_mom_name = os.path.join(self.prefix, "rank_{}_softmax_weight_mom.pt".format(self.rank))
+
+        if resume:
+            try:
+                self.weight: torch.Tensor = torch.load(self.weight_name)
+                self.weight_mom: torch.Tensor = torch.load(self.weight_mom_name)
+                if self.weight.shape[0] != self.num_local or self.weight_mom.shape[0] != self.num_local:
+                    raise IndexError
+                logging.info("softmax weight resume successfully!")
+                logging.info("softmax weight mom resume successfully!")
+            except (FileNotFoundError, KeyError, IndexError):
+                self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
+                self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
+                logging.info("softmax weight init!")
+                logging.info("softmax weight mom init!")
+        else:
+            self.weight = torch.normal(0, 0.01, (self.num_local, self.embedding_size), device=self.device)
+            self.weight_mom: torch.Tensor = torch.zeros_like(self.weight)
+            logging.info("softmax weight init successfully!")
+            logging.info("softmax weight mom init successfully!")
+        self.stream: torch.cuda.Stream = torch.cuda.Stream(local_rank)
+
+        self.index = None
+        if int(self.sample_rate) == 1:
+            self.update = lambda: 0
+            self.sub_weight = Parameter(self.weight)
+            self.sub_weight_mom = self.weight_mom
+        else:
+            self.sub_weight = Parameter(torch.empty((0, 0)).cuda(local_rank))
+
+    def save_params(self):
+        """ Save softmax weight for each rank on prefix
+        """
+        torch.save(self.weight.data, self.weight_name)
+        torch.save(self.weight_mom, self.weight_mom_name)
+
+    @torch.no_grad()
+    def sample(self, total_label):
+        """
+        Sample all positive class centers in each rank, and random select neg class centers to filling a fixed
+        `num_sample`.
+
+        total_label: tensor
+            Label after all gather, which cross all GPUs.
+        """
+        index_positive = (self.class_start <= total_label) & (total_label < self.class_start + self.num_local)
+        total_label[~index_positive] = -1
+        total_label[index_positive] -= self.class_start
+        if int(self.sample_rate) != 1:
+            positive = torch.unique(total_label[index_positive], sorted=True)
+            if self.num_sample - positive.size(0) >= 0:
+                perm = torch.rand(size=[self.num_local], device=self.device)
+                perm[positive] = 2.0
+                index = torch.topk(perm, k=self.num_sample)[1]
+                index = index.sort()[0]
+            else:
+                index = positive
+            self.index = index
+            total_label[index_positive] = torch.searchsorted(index, total_label[index_positive])
+            self.sub_weight = Parameter(self.weight[index])
+            self.sub_weight_mom = self.weight_mom[index]
+
+    def forward(self, total_features, norm_weight):
+        """ Partial fc forward, `logits = X * sample(W)`
+        """
+        torch.cuda.current_stream().wait_stream(self.stream)
+        logits = linear(total_features, norm_weight)
+        return logits
+
+    @torch.no_grad()
+    def update(self):
+        """ Set updated weight and weight_mom to memory bank.
+        """
+        self.weight_mom[self.index] = self.sub_weight_mom
+        self.weight[self.index] = self.sub_weight
+
+    def prepare(self, label, optimizer):
+        """
+        get sampled class centers for cal softmax.
+
+        label: tensor
+            Label tensor on each rank.
+        optimizer: opt
+            Optimizer for partial fc, which need to get weight mom.
+        """
+        with torch.cuda.stream(self.stream):
+            total_label = torch.zeros(
+                size=[self.batch_size * self.world_size], device=self.device, dtype=torch.long)
+            dist.all_gather(list(total_label.chunk(self.world_size, dim=0)), label)
+            self.sample(total_label)
+            optimizer.state.pop(optimizer.param_groups[-1]['params'][0], None)
+            optimizer.param_groups[-1]['params'][0] = self.sub_weight
+            optimizer.state[self.sub_weight]['momentum_buffer'] = self.sub_weight_mom
+            norm_weight = normalize(self.sub_weight)
+            return total_label, norm_weight
+
+    def forward_backward(self, label, features, optimizer):
+        """
+        Partial fc forward and backward with model parallel
+
+        label: tensor
+            Label tensor on each rank(GPU)
+        features: tensor
+            Features tensor on each rank(GPU)
+        optimizer: optimizer
+            Optimizer for partial fc
+
+        Returns:
+        --------
+        x_grad: tensor
+            The gradient of features.
+        loss_v: tensor
+            Loss value for cross entropy.
+        """
+        total_label, norm_weight = self.prepare(label, optimizer)
+        total_features = torch.zeros(
+            size=[self.batch_size * self.world_size, self.embedding_size], device=self.device)
+        dist.all_gather(list(total_features.chunk(self.world_size, dim=0)), features.data)
+        total_features.requires_grad = True
+
+        logits = self.forward(total_features, norm_weight)
+        logits = self.margin_softmax(logits, total_label)
+
+        with torch.no_grad():
+            max_fc = torch.max(logits, dim=1, keepdim=True)[0]
+            dist.all_reduce(max_fc, dist.ReduceOp.MAX)
+
+            # calculate exp(logits) and all-reduce
+            logits_exp = torch.exp(logits - max_fc)
+            logits_sum_exp = logits_exp.sum(dim=1, keepdims=True)
+            dist.all_reduce(logits_sum_exp, dist.ReduceOp.SUM)
+
+            # calculate prob
+            logits_exp.div_(logits_sum_exp)
+
+            # get one-hot
+            grad = logits_exp
+            index = torch.where(total_label != -1)[0]
+            one_hot = torch.zeros(size=[index.size()[0], grad.size()[1]], device=grad.device)
+            one_hot.scatter_(1, total_label[index, None], 1)
+
+            # calculate loss
+            loss = torch.zeros(grad.size()[0], 1, device=grad.device)
+            loss[index] = grad[index].gather(1, total_label[index, None])
+            dist.all_reduce(loss, dist.ReduceOp.SUM)
+            loss_v = loss.clamp_min_(1e-30).log_().mean() * (-1)
+
+            # calculate grad
+            grad[index] -= one_hot
+            grad.div_(self.batch_size * self.world_size)
+
+        logits.backward(grad)
+        if total_features.grad is not None:
+            total_features.grad.detach_()
+        x_grad: torch.Tensor = torch.zeros_like(features, requires_grad=True)
+        # feature gradient all-reduce
+        dist.reduce_scatter(x_grad, list(total_features.grad.chunk(self.world_size, dim=0)))
+        x_grad = x_grad * self.world_size
+        # backward backbone
+        return x_grad, loss_v
diff --git a/src/face3d/models/arcface_torch/requirement.txt b/src/face3d/models/arcface_torch/requirement.txt
new file mode 100644
index 0000000000000000000000000000000000000000..99aef673e30b99cbe56ce82a564c1df9df24ba21
--- /dev/null
+++ b/src/face3d/models/arcface_torch/requirement.txt
@@ -0,0 +1,5 @@
+tensorboard
+easydict
+mxnet
+onnx
+sklearn
diff --git a/src/face3d/models/arcface_torch/run.sh b/src/face3d/models/arcface_torch/run.sh
new file mode 100644
index 0000000000000000000000000000000000000000..67b25fd63ef3921733d81d5be844aacc5a5c84ed
--- /dev/null
+++ b/src/face3d/models/arcface_torch/run.sh
@@ -0,0 +1,2 @@
+CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
+ps -ef | grep "train" | grep -v grep | awk '{print "kill -9 "$2}' | sh
diff --git a/src/face3d/models/arcface_torch/torch2onnx.py b/src/face3d/models/arcface_torch/torch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..458660df7cc7f9a567aaf492c45f232e776a9ef0
--- /dev/null
+++ b/src/face3d/models/arcface_torch/torch2onnx.py
@@ -0,0 +1,59 @@
+import numpy as np
+import onnx
+import torch
+
+
+def convert_onnx(net, path_module, output, opset=11, simplify=False):
+    assert isinstance(net, torch.nn.Module)
+    img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
+    img = img.astype(np.float)
+    img = (img / 255. - 0.5) / 0.5  # torch style norm
+    img = img.transpose((2, 0, 1))
+    img = torch.from_numpy(img).unsqueeze(0).float()
+
+    weight = torch.load(path_module)
+    net.load_state_dict(weight)
+    net.eval()
+    torch.onnx.export(net, img, output, keep_initializers_as_inputs=False, verbose=False, opset_version=opset)
+    model = onnx.load(output)
+    graph = model.graph
+    graph.input[0].type.tensor_type.shape.dim[0].dim_param = 'None'
+    if simplify:
+        from onnxsim import simplify
+        model, check = simplify(model)
+        assert check, "Simplified ONNX model could not be validated"
+    onnx.save(model, output)
+
+    
+if __name__ == '__main__':
+    import os
+    import argparse
+    from backbones import get_model
+
+    parser = argparse.ArgumentParser(description='ArcFace PyTorch to onnx')
+    parser.add_argument('input', type=str, help='input backbone.pth file or path')
+    parser.add_argument('--output', type=str, default=None, help='output onnx path')
+    parser.add_argument('--network', type=str, default=None, help='backbone network')
+    parser.add_argument('--simplify', type=bool, default=False, help='onnx simplify')
+    args = parser.parse_args()
+    input_file = args.input
+    if os.path.isdir(input_file):
+        input_file = os.path.join(input_file, "backbone.pth")
+    assert os.path.exists(input_file)
+    model_name = os.path.basename(os.path.dirname(input_file)).lower()
+    params = model_name.split("_")
+    if len(params) >= 3 and params[1] in ('arcface', 'cosface'):
+        if args.network is None:
+            args.network = params[2]
+    assert args.network is not None
+    print(args)
+    backbone_onnx = get_model(args.network, dropout=0)
+
+    output_path = args.output
+    if output_path is None:
+        output_path = os.path.join(os.path.dirname(__file__), 'onnx')
+    if not os.path.exists(output_path):
+        os.makedirs(output_path)
+    assert os.path.isdir(output_path)
+    output_file = os.path.join(output_path, "%s.onnx" % model_name)
+    convert_onnx(backbone_onnx, input_file, output_file, simplify=args.simplify)
diff --git a/src/face3d/models/arcface_torch/train.py b/src/face3d/models/arcface_torch/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c5491de9af8fc7a2f3d0648c53b89584864f20e
--- /dev/null
+++ b/src/face3d/models/arcface_torch/train.py
@@ -0,0 +1,141 @@
+import argparse
+import logging
+import os
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+import torch.utils.data.distributed
+from torch.nn.utils import clip_grad_norm_
+
+import losses
+from backbones import get_model
+from dataset import MXFaceDataset, SyntheticDataset, DataLoaderX
+from partial_fc import PartialFC
+from utils.utils_amp import MaxClipGradScaler
+from utils.utils_callbacks import CallBackVerification, CallBackLogging, CallBackModelCheckpoint
+from utils.utils_config import get_config
+from utils.utils_logging import AverageMeter, init_logging
+
+
+def main(args):
+    cfg = get_config(args.config)
+    try:
+        world_size = int(os.environ['WORLD_SIZE'])
+        rank = int(os.environ['RANK'])
+        dist.init_process_group('nccl')
+    except KeyError:
+        world_size = 1
+        rank = 0
+        dist.init_process_group(backend='nccl', init_method="tcp://127.0.0.1:12584", rank=rank, world_size=world_size)
+
+    local_rank = args.local_rank
+    torch.cuda.set_device(local_rank)
+    os.makedirs(cfg.output, exist_ok=True)
+    init_logging(rank, cfg.output)
+
+    if cfg.rec == "synthetic":
+        train_set = SyntheticDataset(local_rank=local_rank)
+    else:
+        train_set = MXFaceDataset(root_dir=cfg.rec, local_rank=local_rank)
+
+    train_sampler = torch.utils.data.distributed.DistributedSampler(train_set, shuffle=True)
+    train_loader = DataLoaderX(
+        local_rank=local_rank, dataset=train_set, batch_size=cfg.batch_size,
+        sampler=train_sampler, num_workers=2, pin_memory=True, drop_last=True)
+    backbone = get_model(cfg.network, dropout=0.0, fp16=cfg.fp16, num_features=cfg.embedding_size).to(local_rank)
+
+    if cfg.resume:
+        try:
+            backbone_pth = os.path.join(cfg.output, "backbone.pth")
+            backbone.load_state_dict(torch.load(backbone_pth, map_location=torch.device(local_rank)))
+            if rank == 0:
+                logging.info("backbone resume successfully!")
+        except (FileNotFoundError, KeyError, IndexError, RuntimeError):
+            if rank == 0:
+                logging.info("resume fail, backbone init successfully!")
+
+    backbone = torch.nn.parallel.DistributedDataParallel(
+        module=backbone, broadcast_buffers=False, device_ids=[local_rank])
+    backbone.train()
+    margin_softmax = losses.get_loss(cfg.loss)
+    module_partial_fc = PartialFC(
+        rank=rank, local_rank=local_rank, world_size=world_size, resume=cfg.resume,
+        batch_size=cfg.batch_size, margin_softmax=margin_softmax, num_classes=cfg.num_classes,
+        sample_rate=cfg.sample_rate, embedding_size=cfg.embedding_size, prefix=cfg.output)
+
+    opt_backbone = torch.optim.SGD(
+        params=[{'params': backbone.parameters()}],
+        lr=cfg.lr / 512 * cfg.batch_size * world_size,
+        momentum=0.9, weight_decay=cfg.weight_decay)
+    opt_pfc = torch.optim.SGD(
+        params=[{'params': module_partial_fc.parameters()}],
+        lr=cfg.lr / 512 * cfg.batch_size * world_size,
+        momentum=0.9, weight_decay=cfg.weight_decay)
+
+    num_image = len(train_set)
+    total_batch_size = cfg.batch_size * world_size
+    cfg.warmup_step = num_image // total_batch_size * cfg.warmup_epoch
+    cfg.total_step = num_image // total_batch_size * cfg.num_epoch
+
+    def lr_step_func(current_step):
+        cfg.decay_step = [x * num_image // total_batch_size for x in cfg.decay_epoch]
+        if current_step < cfg.warmup_step:
+            return current_step / cfg.warmup_step
+        else:
+            return 0.1 ** len([m for m in cfg.decay_step if m <= current_step])
+
+    scheduler_backbone = torch.optim.lr_scheduler.LambdaLR(
+        optimizer=opt_backbone, lr_lambda=lr_step_func)
+    scheduler_pfc = torch.optim.lr_scheduler.LambdaLR(
+        optimizer=opt_pfc, lr_lambda=lr_step_func)
+
+    for key, value in cfg.items():
+        num_space = 25 - len(key)
+        logging.info(": " + key + " " * num_space + str(value))
+
+    val_target = cfg.val_targets
+    callback_verification = CallBackVerification(2000, rank, val_target, cfg.rec)
+    callback_logging = CallBackLogging(50, rank, cfg.total_step, cfg.batch_size, world_size, None)
+    callback_checkpoint = CallBackModelCheckpoint(rank, cfg.output)
+
+    loss = AverageMeter()
+    start_epoch = 0
+    global_step = 0
+    grad_amp = MaxClipGradScaler(cfg.batch_size, 128 * cfg.batch_size, growth_interval=100) if cfg.fp16 else None
+    for epoch in range(start_epoch, cfg.num_epoch):
+        train_sampler.set_epoch(epoch)
+        for step, (img, label) in enumerate(train_loader):
+            global_step += 1
+            features = F.normalize(backbone(img))
+            x_grad, loss_v = module_partial_fc.forward_backward(label, features, opt_pfc)
+            if cfg.fp16:
+                features.backward(grad_amp.scale(x_grad))
+                grad_amp.unscale_(opt_backbone)
+                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+                grad_amp.step(opt_backbone)
+                grad_amp.update()
+            else:
+                features.backward(x_grad)
+                clip_grad_norm_(backbone.parameters(), max_norm=5, norm_type=2)
+                opt_backbone.step()
+
+            opt_pfc.step()
+            module_partial_fc.update()
+            opt_backbone.zero_grad()
+            opt_pfc.zero_grad()
+            loss.update(loss_v, 1)
+            callback_logging(global_step, loss, epoch, cfg.fp16, scheduler_backbone.get_last_lr()[0], grad_amp)
+            callback_verification(global_step, backbone)
+            scheduler_backbone.step()
+            scheduler_pfc.step()
+        callback_checkpoint(global_step, backbone, module_partial_fc)
+    dist.destroy_process_group()
+
+
+if __name__ == "__main__":
+    torch.backends.cudnn.benchmark = True
+    parser = argparse.ArgumentParser(description='PyTorch ArcFace Training')
+    parser.add_argument('config', type=str, help='py config file')
+    parser.add_argument('--local_rank', type=int, default=0, help='local_rank')
+    main(parser.parse_args())
diff --git a/src/face3d/models/arcface_torch/utils/__init__.py b/src/face3d/models/arcface_torch/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/arcface_torch/utils/plot.py b/src/face3d/models/arcface_torch/utils/plot.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fce6cc0ae526d5aebc8e7a1550300ceae3a2034
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/plot.py
@@ -0,0 +1,72 @@
+# coding: utf-8
+
+import os
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+from menpo.visualize.viewmatplotlib import sample_colours_from_colourmap
+from prettytable import PrettyTable
+from sklearn.metrics import roc_curve, auc
+
+image_path = "/data/anxiang/IJB_release/IJBC"
+files = [
+        "./ms1mv3_arcface_r100/ms1mv3_arcface_r100/ijbc.npy"
+]
+
+
+def read_template_pair_list(path):
+    pairs = pd.read_csv(path, sep=' ', header=None).values
+    t1 = pairs[:, 0].astype(np.int)
+    t2 = pairs[:, 1].astype(np.int)
+    label = pairs[:, 2].astype(np.int)
+    return t1, t2, label
+
+
+p1, p2, label = read_template_pair_list(
+    os.path.join('%s/meta' % image_path,
+                 '%s_template_pair_label.txt' % 'ijbc'))
+
+methods = []
+scores = []
+for file in files:
+    methods.append(file.split('/')[-2])
+    scores.append(np.load(file))
+
+methods = np.array(methods)
+scores = dict(zip(methods, scores))
+colours = dict(
+    zip(methods, sample_colours_from_colourmap(methods.shape[0], 'Set2')))
+x_labels = [10 ** -6, 10 ** -5, 10 ** -4, 10 ** -3, 10 ** -2, 10 ** -1]
+tpr_fpr_table = PrettyTable(['Methods'] + [str(x) for x in x_labels])
+fig = plt.figure()
+for method in methods:
+    fpr, tpr, _ = roc_curve(label, scores[method])
+    roc_auc = auc(fpr, tpr)
+    fpr = np.flipud(fpr)
+    tpr = np.flipud(tpr)  # select largest tpr at same fpr
+    plt.plot(fpr,
+             tpr,
+             color=colours[method],
+             lw=1,
+             label=('[%s (AUC = %0.4f %%)]' %
+                    (method.split('-')[-1], roc_auc * 100)))
+    tpr_fpr_row = []
+    tpr_fpr_row.append("%s-%s" % (method, "IJBC"))
+    for fpr_iter in np.arange(len(x_labels)):
+        _, min_index = min(
+            list(zip(abs(fpr - x_labels[fpr_iter]), range(len(fpr)))))
+        tpr_fpr_row.append('%.2f' % (tpr[min_index] * 100))
+    tpr_fpr_table.add_row(tpr_fpr_row)
+plt.xlim([10 ** -6, 0.1])
+plt.ylim([0.3, 1.0])
+plt.grid(linestyle='--', linewidth=1)
+plt.xticks(x_labels)
+plt.yticks(np.linspace(0.3, 1.0, 8, endpoint=True))
+plt.xscale('log')
+plt.xlabel('False Positive Rate')
+plt.ylabel('True Positive Rate')
+plt.title('ROC on IJB')
+plt.legend(loc="lower right")
+print(tpr_fpr_table)
diff --git a/src/face3d/models/arcface_torch/utils/utils_amp.py b/src/face3d/models/arcface_torch/utils/utils_amp.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d5bcbb540ff8b04535e71c0057e124338df5bd
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/utils_amp.py
@@ -0,0 +1,88 @@
+from typing import Dict, List
+
+import torch
+
+if torch.__version__ < '1.9':
+    Iterable = torch._six.container_abcs.Iterable
+else:
+    import collections
+
+    Iterable = collections.abc.Iterable
+from torch.cuda.amp import GradScaler
+
+
+class _MultiDeviceReplicator(object):
+    """
+    Lazily serves copies of a tensor to requested devices.  Copies are cached per-device.
+    """
+
+    def __init__(self, master_tensor: torch.Tensor) -> None:
+        assert master_tensor.is_cuda
+        self.master = master_tensor
+        self._per_device_tensors: Dict[torch.device, torch.Tensor] = {}
+
+    def get(self, device) -> torch.Tensor:
+        retval = self._per_device_tensors.get(device, None)
+        if retval is None:
+            retval = self.master.to(device=device, non_blocking=True, copy=True)
+            self._per_device_tensors[device] = retval
+        return retval
+
+
+class MaxClipGradScaler(GradScaler):
+    def __init__(self, init_scale, max_scale: float, growth_interval=100):
+        GradScaler.__init__(self, init_scale=init_scale, growth_interval=growth_interval)
+        self.max_scale = max_scale
+
+    def scale_clip(self):
+        if self.get_scale() == self.max_scale:
+            self.set_growth_factor(1)
+        elif self.get_scale() < self.max_scale:
+            self.set_growth_factor(2)
+        elif self.get_scale() > self.max_scale:
+            self._scale.fill_(self.max_scale)
+            self.set_growth_factor(1)
+
+    def scale(self, outputs):
+        """
+        Multiplies ('scales') a tensor or list of tensors by the scale factor.
+
+        Returns scaled outputs.  If this instance of :class:`GradScaler` is not enabled, outputs are returned
+        unmodified.
+
+        Arguments:
+            outputs (Tensor or iterable of Tensors):  Outputs to scale.
+        """
+        if not self._enabled:
+            return outputs
+        self.scale_clip()
+        # Short-circuit for the common case.
+        if isinstance(outputs, torch.Tensor):
+            assert outputs.is_cuda
+            if self._scale is None:
+                self._lazy_init_scale_growth_tracker(outputs.device)
+            assert self._scale is not None
+            return outputs * self._scale.to(device=outputs.device, non_blocking=True)
+
+        # Invoke the more complex machinery only if we're treating multiple outputs.
+        stash: List[_MultiDeviceReplicator] = []  # holds a reference that can be overwritten by apply_scale
+
+        def apply_scale(val):
+            if isinstance(val, torch.Tensor):
+                assert val.is_cuda
+                if len(stash) == 0:
+                    if self._scale is None:
+                        self._lazy_init_scale_growth_tracker(val.device)
+                    assert self._scale is not None
+                    stash.append(_MultiDeviceReplicator(self._scale))
+                return val * stash[0].get(val.device)
+            elif isinstance(val, Iterable):
+                iterable = map(apply_scale, val)
+                if isinstance(val, list) or isinstance(val, tuple):
+                    return type(val)(iterable)
+                else:
+                    return iterable
+            else:
+                raise ValueError("outputs must be a Tensor or an iterable of Tensors")
+
+        return apply_scale(outputs)
diff --git a/src/face3d/models/arcface_torch/utils/utils_callbacks.py b/src/face3d/models/arcface_torch/utils/utils_callbacks.py
new file mode 100644
index 0000000000000000000000000000000000000000..748923b36358bd118efa0532a6f512b6ca96ff34
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/utils_callbacks.py
@@ -0,0 +1,117 @@
+import logging
+import os
+import time
+from typing import List
+
+import torch
+
+from eval import verification
+from utils.utils_logging import AverageMeter
+
+
+class CallBackVerification(object):
+    def __init__(self, frequent, rank, val_targets, rec_prefix, image_size=(112, 112)):
+        self.frequent: int = frequent
+        self.rank: int = rank
+        self.highest_acc: float = 0.0
+        self.highest_acc_list: List[float] = [0.0] * len(val_targets)
+        self.ver_list: List[object] = []
+        self.ver_name_list: List[str] = []
+        if self.rank is 0:
+            self.init_dataset(val_targets=val_targets, data_dir=rec_prefix, image_size=image_size)
+
+    def ver_test(self, backbone: torch.nn.Module, global_step: int):
+        results = []
+        for i in range(len(self.ver_list)):
+            acc1, std1, acc2, std2, xnorm, embeddings_list = verification.test(
+                self.ver_list[i], backbone, 10, 10)
+            logging.info('[%s][%d]XNorm: %f' % (self.ver_name_list[i], global_step, xnorm))
+            logging.info('[%s][%d]Accuracy-Flip: %1.5f+-%1.5f' % (self.ver_name_list[i], global_step, acc2, std2))
+            if acc2 > self.highest_acc_list[i]:
+                self.highest_acc_list[i] = acc2
+            logging.info(
+                '[%s][%d]Accuracy-Highest: %1.5f' % (self.ver_name_list[i], global_step, self.highest_acc_list[i]))
+            results.append(acc2)
+
+    def init_dataset(self, val_targets, data_dir, image_size):
+        for name in val_targets:
+            path = os.path.join(data_dir, name + ".bin")
+            if os.path.exists(path):
+                data_set = verification.load_bin(path, image_size)
+                self.ver_list.append(data_set)
+                self.ver_name_list.append(name)
+
+    def __call__(self, num_update, backbone: torch.nn.Module):
+        if self.rank is 0 and num_update > 0 and num_update % self.frequent == 0:
+            backbone.eval()
+            self.ver_test(backbone, num_update)
+            backbone.train()
+
+
+class CallBackLogging(object):
+    def __init__(self, frequent, rank, total_step, batch_size, world_size, writer=None):
+        self.frequent: int = frequent
+        self.rank: int = rank
+        self.time_start = time.time()
+        self.total_step: int = total_step
+        self.batch_size: int = batch_size
+        self.world_size: int = world_size
+        self.writer = writer
+
+        self.init = False
+        self.tic = 0
+
+    def __call__(self,
+                 global_step: int,
+                 loss: AverageMeter,
+                 epoch: int,
+                 fp16: bool,
+                 learning_rate: float,
+                 grad_scaler: torch.cuda.amp.GradScaler):
+        if self.rank == 0 and global_step > 0 and global_step % self.frequent == 0:
+            if self.init:
+                try:
+                    speed: float = self.frequent * self.batch_size / (time.time() - self.tic)
+                    speed_total = speed * self.world_size
+                except ZeroDivisionError:
+                    speed_total = float('inf')
+
+                time_now = (time.time() - self.time_start) / 3600
+                time_total = time_now / ((global_step + 1) / self.total_step)
+                time_for_end = time_total - time_now
+                if self.writer is not None:
+                    self.writer.add_scalar('time_for_end', time_for_end, global_step)
+                    self.writer.add_scalar('learning_rate', learning_rate, global_step)
+                    self.writer.add_scalar('loss', loss.avg, global_step)
+                if fp16:
+                    msg = "Speed %.2f samples/sec   Loss %.4f   LearningRate %.4f   Epoch: %d   Global Step: %d   " \
+                          "Fp16 Grad Scale: %2.f   Required: %1.f hours" % (
+                              speed_total, loss.avg, learning_rate, epoch, global_step,
+                              grad_scaler.get_scale(), time_for_end
+                          )
+                else:
+                    msg = "Speed %.2f samples/sec   Loss %.4f   LearningRate %.4f   Epoch: %d   Global Step: %d   " \
+                          "Required: %1.f hours" % (
+                              speed_total, loss.avg, learning_rate, epoch, global_step, time_for_end
+                          )
+                logging.info(msg)
+                loss.reset()
+                self.tic = time.time()
+            else:
+                self.init = True
+                self.tic = time.time()
+
+
+class CallBackModelCheckpoint(object):
+    def __init__(self, rank, output="./"):
+        self.rank: int = rank
+        self.output: str = output
+
+    def __call__(self, global_step, backbone, partial_fc, ):
+        if global_step > 100 and self.rank == 0:
+            path_module = os.path.join(self.output, "backbone.pth")
+            torch.save(backbone.module.state_dict(), path_module)
+            logging.info("Pytorch Model Saved in '{}'".format(path_module))
+
+        if global_step > 100 and partial_fc is not None:
+            partial_fc.save_params()
diff --git a/src/face3d/models/arcface_torch/utils/utils_config.py b/src/face3d/models/arcface_torch/utils/utils_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b60a1e5a2e860ce5511a2d3863c8b57a4df292d7
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/utils_config.py
@@ -0,0 +1,16 @@
+import importlib
+import os.path as osp
+
+
+def get_config(config_file):
+    assert config_file.startswith('configs/'), 'config file setting must start with configs/'
+    temp_config_name = osp.basename(config_file)
+    temp_module_name = osp.splitext(temp_config_name)[0]
+    config = importlib.import_module("configs.base")
+    cfg = config.config
+    config = importlib.import_module("configs.%s" % temp_module_name)
+    job_cfg = config.config
+    cfg.update(job_cfg)
+    if cfg.output is None:
+        cfg.output = osp.join('work_dirs', temp_module_name)
+    return cfg
\ No newline at end of file
diff --git a/src/face3d/models/arcface_torch/utils/utils_logging.py b/src/face3d/models/arcface_torch/utils/utils_logging.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2b43b851c9e06230abd94c73a1f64cfa1b6f3ac
--- /dev/null
+++ b/src/face3d/models/arcface_torch/utils/utils_logging.py
@@ -0,0 +1,41 @@
+import logging
+import os
+import sys
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value
+    """
+
+    def __init__(self):
+        self.val = None
+        self.avg = None
+        self.sum = None
+        self.count = None
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def init_logging(rank, models_root):
+    if rank == 0:
+        log_root = logging.getLogger()
+        log_root.setLevel(logging.INFO)
+        formatter = logging.Formatter("Training: %(asctime)s-%(message)s")
+        handler_file = logging.FileHandler(os.path.join(models_root, "training.log"))
+        handler_stream = logging.StreamHandler(sys.stdout)
+        handler_file.setFormatter(formatter)
+        handler_stream.setFormatter(formatter)
+        log_root.addHandler(handler_file)
+        log_root.addHandler(handler_stream)
+        log_root.info('rank_id: %d' % rank)
diff --git a/src/face3d/models/arcface_torch/utils/utils_os.py b/src/face3d/models/arcface_torch/utils/utils_os.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/src/face3d/models/base_model.py b/src/face3d/models/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b975223f6148febfe32d20d63980583c97b61eb3
--- /dev/null
+++ b/src/face3d/models/base_model.py
@@ -0,0 +1,316 @@
+"""This script defines the base network model for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch
+from collections import OrderedDict
+from abc import ABC, abstractmethod
+from . import networks
+
+
+class BaseModel(ABC):
+    """This class is an abstract base class (ABC) for models.
+    To create a subclass, you need to implement the following five functions:
+        -- <__init__>:                      initialize the class; first call BaseModel.__init__(self, opt).
+        -- <set_input>:                     unpack data from dataset and apply preprocessing.
+        -- <forward>:                       produce intermediate results.
+        -- <optimize_parameters>:           calculate losses, gradients, and update network weights.
+        -- <modify_commandline_options>:    (optionally) add model-specific options and set default options.
+    """
+
+    def __init__(self, opt):
+        """Initialize the BaseModel class.
+
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+
+        When creating your custom class, you need to implement your own initialization.
+        In this fucntion, you should first call <BaseModel.__init__(self, opt)>
+        Then, you need to define four lists:
+            -- self.loss_names (str list):          specify the training losses that you want to plot and save.
+            -- self.model_names (str list):         specify the images that you want to display and save.
+            -- self.visual_names (str list):        define networks used in our training.
+            -- self.optimizers (optimizer list):    define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+        """
+        self.opt = opt
+        self.isTrain = False
+        self.device = torch.device('cpu') 
+        self.save_dir = " " # os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
+        self.loss_names = []
+        self.model_names = []
+        self.visual_names = []
+        self.parallel_names = []
+        self.optimizers = []
+        self.image_paths = []
+        self.metric = 0  # used for learning rate policy 'plateau'
+
+    @staticmethod
+    def dict_grad_hook_factory(add_func=lambda x: x):
+        saved_dict = dict()
+
+        def hook_gen(name):
+            def grad_hook(grad):
+                saved_vals = add_func(grad)
+                saved_dict[name] = saved_vals
+            return grad_hook
+        return hook_gen, saved_dict
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train):
+        """Add new model-specific options, and rewrite default values for existing options.
+
+        Parameters:
+            parser          -- original option parser
+            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        return parser
+
+    @abstractmethod
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input (dict): includes the data itself and its metadata information.
+        """
+        pass
+
+    @abstractmethod
+    def forward(self):
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        pass
+
+    @abstractmethod
+    def optimize_parameters(self):
+        """Calculate losses, gradients, and update network weights; called in every training iteration"""
+        pass
+
+    def setup(self, opt):
+        """Load and print networks; create schedulers
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        if self.isTrain:
+            self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
+        
+        if not self.isTrain or opt.continue_train:
+            load_suffix = opt.epoch
+            self.load_networks(load_suffix)
+ 
+            
+        # self.print_networks(opt.verbose)
+
+    def parallelize(self, convert_sync_batchnorm=True):
+        if not self.opt.use_ddp:
+            for name in self.parallel_names:
+                if isinstance(name, str):
+                    module = getattr(self, name)
+                    setattr(self, name, module.to(self.device))
+        else:
+            for name in self.model_names:
+                if isinstance(name, str):
+                    module = getattr(self, name)
+                    if convert_sync_batchnorm:
+                        module = torch.nn.SyncBatchNorm.convert_sync_batchnorm(module)
+                    setattr(self, name, torch.nn.parallel.DistributedDataParallel(module.to(self.device),
+                        device_ids=[self.device.index], 
+                        find_unused_parameters=True, broadcast_buffers=True))
+            
+            # DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient.
+            for name in self.parallel_names:
+                if isinstance(name, str) and name not in self.model_names:
+                    module = getattr(self, name)
+                    setattr(self, name, module.to(self.device))
+            
+        # put state_dict of optimizer to gpu device
+        if self.opt.phase != 'test':
+            if self.opt.continue_train:
+                for optim in self.optimizers:
+                    for state in optim.state.values():
+                        for k, v in state.items():
+                            if isinstance(v, torch.Tensor):
+                                state[k] = v.to(self.device)
+
+    def data_dependent_initialize(self, data):
+        pass
+
+    def train(self):
+        """Make models train mode"""
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, name)
+                net.train()
+
+    def eval(self):
+        """Make models eval mode"""
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, name)
+                net.eval()
+
+    def test(self):
+        """Forward function used in test time.
+
+        This function wraps <forward> function in no_grad() so we don't save intermediate steps for backprop
+        It also calls <compute_visuals> to produce additional visualization results
+        """
+        with torch.no_grad():
+            self.forward()
+            self.compute_visuals()
+
+    def compute_visuals(self):
+        """Calculate additional output images for visdom and HTML visualization"""
+        pass
+
+    def get_image_paths(self, name='A'):
+        """ Return image paths that are used to load current data"""
+        return self.image_paths if name =='A' else self.image_paths_B
+
+    def update_learning_rate(self):
+        """Update learning rates for all the networks; called at the end of every epoch"""
+        for scheduler in self.schedulers:
+            if self.opt.lr_policy == 'plateau':
+                scheduler.step(self.metric)
+            else:
+                scheduler.step()
+
+        lr = self.optimizers[0].param_groups[0]['lr']
+        print('learning rate = %.7f' % lr)
+
+    def get_current_visuals(self):
+        """Return visualization images. train.py will display these images with visdom, and save the images to a HTML"""
+        visual_ret = OrderedDict()
+        for name in self.visual_names:
+            if isinstance(name, str):
+                visual_ret[name] = getattr(self, name)[:, :3, ...]
+        return visual_ret
+
+    def get_current_losses(self):
+        """Return traning losses / errors. train.py will print out these errors on console, and save them to a file"""
+        errors_ret = OrderedDict()
+        for name in self.loss_names:
+            if isinstance(name, str):
+                errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
+        return errors_ret
+
+    def save_networks(self, epoch):
+        """Save all the networks to the disk.
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        if not os.path.isdir(self.save_dir):
+            os.makedirs(self.save_dir)
+
+        save_filename = 'epoch_%s.pth' % (epoch)
+        save_path = os.path.join(self.save_dir, save_filename)
+        
+        save_dict = {}
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, name)
+                if isinstance(net, torch.nn.DataParallel) or isinstance(net,
+                        torch.nn.parallel.DistributedDataParallel):
+                    net = net.module
+                save_dict[name] = net.state_dict()
+                
+
+        for i, optim in enumerate(self.optimizers):
+            save_dict['opt_%02d'%i] = optim.state_dict()
+
+        for i, sched in enumerate(self.schedulers):
+            save_dict['sched_%02d'%i] = sched.state_dict()
+        
+        torch.save(save_dict, save_path)
+
+    def __patch_instance_norm_state_dict(self, state_dict, module, keys, i=0):
+        """Fix InstanceNorm checkpoints incompatibility (prior to 0.4)"""
+        key = keys[i]
+        if i + 1 == len(keys):  # at the end, pointing to a parameter/buffer
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+                    (key == 'running_mean' or key == 'running_var'):
+                if getattr(module, key) is None:
+                    state_dict.pop('.'.join(keys))
+            if module.__class__.__name__.startswith('InstanceNorm') and \
+               (key == 'num_batches_tracked'):
+                state_dict.pop('.'.join(keys))
+        else:
+            self.__patch_instance_norm_state_dict(state_dict, getattr(module, key), keys, i + 1)
+
+    def load_networks(self, epoch):
+        """Load all the networks from the disk.
+
+        Parameters:
+            epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
+        """
+        if self.opt.isTrain and self.opt.pretrained_name is not None:
+            load_dir = os.path.join(self.opt.checkpoints_dir, self.opt.pretrained_name)
+        else:
+            load_dir = self.save_dir    
+        load_filename = 'epoch_%s.pth' % (epoch)
+        load_path = os.path.join(load_dir, load_filename)
+        state_dict = torch.load(load_path, map_location=self.device)
+        print('loading the model from %s' % load_path)
+
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, name)
+                if isinstance(net, torch.nn.DataParallel):
+                    net = net.module
+                net.load_state_dict(state_dict[name])
+        
+        if self.opt.phase != 'test':
+            if self.opt.continue_train:
+                print('loading the optim from %s' % load_path)
+                for i, optim in enumerate(self.optimizers):
+                    optim.load_state_dict(state_dict['opt_%02d'%i])
+
+                try:
+                    print('loading the sched from %s' % load_path)
+                    for i, sched in enumerate(self.schedulers):
+                        sched.load_state_dict(state_dict['sched_%02d'%i])
+                except:
+                    print('Failed to load schedulers, set schedulers according to epoch count manually')
+                    for i, sched in enumerate(self.schedulers):
+                        sched.last_epoch = self.opt.epoch_count - 1
+                    
+
+            
+
+    def print_networks(self, verbose):
+        """Print the total number of parameters in the network and (if verbose) network architecture
+
+        Parameters:
+            verbose (bool) -- if verbose: print the network architecture
+        """
+        print('---------- Networks initialized -------------')
+        for name in self.model_names:
+            if isinstance(name, str):
+                net = getattr(self, name)
+                num_params = 0
+                for param in net.parameters():
+                    num_params += param.numel()
+                if verbose:
+                    print(net)
+                print('[Network %s] Total number of parameters : %.3f M' % (name, num_params / 1e6))
+        print('-----------------------------------------------')
+
+    def set_requires_grad(self, nets, requires_grad=False):
+        """Set requies_grad=Fasle for all the networks to avoid unnecessary computations
+        Parameters:
+            nets (network list)   -- a list of networks
+            requires_grad (bool)  -- whether the networks require gradients or not
+        """
+        if not isinstance(nets, list):
+            nets = [nets]
+        for net in nets:
+            if net is not None:
+                for param in net.parameters():
+                    param.requires_grad = requires_grad
+
+    def generate_visuals_for_evaluation(self, data, mode):
+        return {}
diff --git a/src/face3d/models/bfm.py b/src/face3d/models/bfm.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cecaf589befac790cf9c124737ba01e27bc29e6
--- /dev/null
+++ b/src/face3d/models/bfm.py
@@ -0,0 +1,331 @@
+"""This script defines the parametric 3d face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import  torch
+import torch.nn.functional as F
+from scipy.io import loadmat
+from src.face3d.util.load_mats import transferBFM09
+import os
+
+def perspective_projection(focal, center):
+    # return p.T (N, 3) @ (3, 3) 
+    return np.array([
+        focal, 0, center,
+        0, focal, center,
+        0, 0, 1
+    ]).reshape([3, 3]).astype(np.float32).transpose()
+
+class SH:
+    def __init__(self):
+        self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
+        self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
+
+
+
+class ParametricFaceModel:
+    def __init__(self, 
+                bfm_folder='./BFM', 
+                recenter=True,
+                camera_distance=10.,
+                init_lit=np.array([
+                    0.8, 0, 0, 0, 0, 0, 0, 0, 0
+                    ]),
+                focal=1015.,
+                center=112.,
+                is_train=True,
+                default_name='BFM_model_front.mat'):
+        
+        if not os.path.isfile(os.path.join(bfm_folder, default_name)):
+            transferBFM09(bfm_folder)
+            
+        model = loadmat(os.path.join(bfm_folder, default_name))
+        # mean face shape. [3*N,1]
+        self.mean_shape = model['meanshape'].astype(np.float32)
+        # identity basis. [3*N,80]
+        self.id_base = model['idBase'].astype(np.float32)
+        # expression basis. [3*N,64]
+        self.exp_base = model['exBase'].astype(np.float32)
+        # mean face texture. [3*N,1] (0-255)
+        self.mean_tex = model['meantex'].astype(np.float32)
+        # texture basis. [3*N,80]
+        self.tex_base = model['texBase'].astype(np.float32)
+        # face indices for each vertex that lies in. starts from 0. [N,8]
+        self.point_buf = model['point_buf'].astype(np.int64) - 1
+        # vertex indices for each face. starts from 0. [F,3]
+        self.face_buf = model['tri'].astype(np.int64) - 1
+        # vertex indices for 68 landmarks. starts from 0. [68,1]
+        self.keypoints = np.squeeze(model['keypoints']).astype(np.int64) - 1
+
+        if is_train:
+            # vertex indices for small face region to compute photometric error. starts from 0.
+            self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
+            # vertex indices for each face from small face region. starts from 0. [f,3]
+            self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
+            # vertex indices for pre-defined skin region to compute reflectance loss
+            self.skin_mask = np.squeeze(model['skinmask'])
+        
+        if recenter:
+            mean_shape = self.mean_shape.reshape([-1, 3])
+            mean_shape = mean_shape - np.mean(mean_shape, axis=0, keepdims=True)
+            self.mean_shape = mean_shape.reshape([-1, 1])
+
+        self.persc_proj = perspective_projection(focal, center)
+        self.device = 'cpu'
+        self.camera_distance = camera_distance
+        self.SH = SH()
+        self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
+        
+
+    def to(self, device):
+        self.device = device
+        for key, value in self.__dict__.items():
+            if type(value).__module__ == np.__name__:
+                setattr(self, key, torch.tensor(value).to(device))
+
+    
+    def compute_shape(self, id_coeff, exp_coeff):
+        """
+        Return:
+            face_shape       -- torch.tensor, size (B, N, 3)
+
+        Parameters:
+            id_coeff         -- torch.tensor, size (B, 80), identity coeffs
+            exp_coeff        -- torch.tensor, size (B, 64), expression coeffs
+        """
+        batch_size = id_coeff.shape[0]
+        id_part = torch.einsum('ij,aj->ai', self.id_base, id_coeff)
+        exp_part = torch.einsum('ij,aj->ai', self.exp_base, exp_coeff)
+        face_shape = id_part + exp_part + self.mean_shape.reshape([1, -1])
+        return face_shape.reshape([batch_size, -1, 3])
+    
+
+    def compute_texture(self, tex_coeff, normalize=True):
+        """
+        Return:
+            face_texture     -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
+
+        Parameters:
+            tex_coeff        -- torch.tensor, size (B, 80)
+        """
+        batch_size = tex_coeff.shape[0]
+        face_texture = torch.einsum('ij,aj->ai', self.tex_base, tex_coeff) + self.mean_tex
+        if normalize:
+            face_texture = face_texture / 255.
+        return face_texture.reshape([batch_size, -1, 3])
+
+
+    def compute_norm(self, face_shape):
+        """
+        Return:
+            vertex_norm      -- torch.tensor, size (B, N, 3)
+
+        Parameters:
+            face_shape       -- torch.tensor, size (B, N, 3)
+        """
+
+        v1 = face_shape[:, self.face_buf[:, 0]]
+        v2 = face_shape[:, self.face_buf[:, 1]]
+        v3 = face_shape[:, self.face_buf[:, 2]]
+        e1 = v1 - v2
+        e2 = v2 - v3
+        face_norm = torch.cross(e1, e2, dim=-1)
+        face_norm = F.normalize(face_norm, dim=-1, p=2)
+        face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.device)], dim=1)
+        
+        vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
+        vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
+        return vertex_norm
+
+
+    def compute_color(self, face_texture, face_norm, gamma):
+        """
+        Return:
+            face_color       -- torch.tensor, size (B, N, 3), range (0, 1.)
+
+        Parameters:
+            face_texture     -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
+            face_norm        -- torch.tensor, size (B, N, 3), rotated face normal
+            gamma            -- torch.tensor, size (B, 27), SH coeffs
+        """
+        batch_size = gamma.shape[0]
+        v_num = face_texture.shape[1]
+        a, c = self.SH.a, self.SH.c
+        gamma = gamma.reshape([batch_size, 3, 9])
+        gamma = gamma + self.init_lit
+        gamma = gamma.permute(0, 2, 1)
+        Y = torch.cat([
+             a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.device),
+            -a[1] * c[1] * face_norm[..., 1:2],
+             a[1] * c[1] * face_norm[..., 2:],
+            -a[1] * c[1] * face_norm[..., :1],
+             a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
+            -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
+            0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
+            -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
+            0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2  - face_norm[..., 1:2] ** 2)
+        ], dim=-1)
+        r = Y @ gamma[..., :1]
+        g = Y @ gamma[..., 1:2]
+        b = Y @ gamma[..., 2:]
+        face_color = torch.cat([r, g, b], dim=-1) * face_texture
+        return face_color
+
+    
+    def compute_rotation(self, angles):
+        """
+        Return:
+            rot              -- torch.tensor, size (B, 3, 3) pts @ trans_mat
+
+        Parameters:
+            angles           -- torch.tensor, size (B, 3), radian
+        """
+
+        batch_size = angles.shape[0]
+        ones = torch.ones([batch_size, 1]).to(self.device)
+        zeros = torch.zeros([batch_size, 1]).to(self.device)
+        x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
+        
+        rot_x = torch.cat([
+            ones, zeros, zeros,
+            zeros, torch.cos(x), -torch.sin(x), 
+            zeros, torch.sin(x), torch.cos(x)
+        ], dim=1).reshape([batch_size, 3, 3])
+        
+        rot_y = torch.cat([
+            torch.cos(y), zeros, torch.sin(y),
+            zeros, ones, zeros,
+            -torch.sin(y), zeros, torch.cos(y)
+        ], dim=1).reshape([batch_size, 3, 3])
+
+        rot_z = torch.cat([
+            torch.cos(z), -torch.sin(z), zeros,
+            torch.sin(z), torch.cos(z), zeros,
+            zeros, zeros, ones
+        ], dim=1).reshape([batch_size, 3, 3])
+
+        rot = rot_z @ rot_y @ rot_x
+        return rot.permute(0, 2, 1)
+
+
+    def to_camera(self, face_shape):
+        face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
+        return face_shape
+
+    def to_image(self, face_shape):
+        """
+        Return:
+            face_proj        -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
+
+        Parameters:
+            face_shape       -- torch.tensor, size (B, N, 3)
+        """
+        # to image_plane
+        face_proj = face_shape @ self.persc_proj
+        face_proj = face_proj[..., :2] / face_proj[..., 2:]
+
+        return face_proj
+
+
+    def transform(self, face_shape, rot, trans):
+        """
+        Return:
+            face_shape       -- torch.tensor, size (B, N, 3) pts @ rot + trans
+
+        Parameters:
+            face_shape       -- torch.tensor, size (B, N, 3)
+            rot              -- torch.tensor, size (B, 3, 3)
+            trans            -- torch.tensor, size (B, 3)
+        """
+        return face_shape @ rot + trans.unsqueeze(1)
+
+
+    def get_landmarks(self, face_proj):
+        """
+        Return:
+            face_lms         -- torch.tensor, size (B, 68, 2)
+
+        Parameters:
+            face_proj       -- torch.tensor, size (B, N, 2)
+        """  
+        return face_proj[:, self.keypoints]
+
+    def split_coeff(self, coeffs):
+        """
+        Return:
+            coeffs_dict     -- a dict of torch.tensors
+
+        Parameters:
+            coeffs          -- torch.tensor, size (B, 256)
+        """
+        id_coeffs = coeffs[:, :80]
+        exp_coeffs = coeffs[:, 80: 144]
+        tex_coeffs = coeffs[:, 144: 224]
+        angles = coeffs[:, 224: 227]
+        gammas = coeffs[:, 227: 254]
+        translations = coeffs[:, 254:]
+        return {
+            'id': id_coeffs,
+            'exp': exp_coeffs,
+            'tex': tex_coeffs,
+            'angle': angles,
+            'gamma': gammas,
+            'trans': translations
+        }
+    def compute_for_render(self, coeffs):
+        """
+        Return:
+            face_vertex     -- torch.tensor, size (B, N, 3), in camera coordinate
+            face_color      -- torch.tensor, size (B, N, 3), in RGB order
+            landmark        -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+        Parameters:
+            coeffs          -- torch.tensor, size (B, 257)
+        """
+        coef_dict = self.split_coeff(coeffs)
+        face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+        rotation = self.compute_rotation(coef_dict['angle'])
+
+
+        face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+        face_vertex = self.to_camera(face_shape_transformed)
+        
+        face_proj = self.to_image(face_vertex)
+        landmark = self.get_landmarks(face_proj)
+
+        face_texture = self.compute_texture(coef_dict['tex'])
+        face_norm = self.compute_norm(face_shape)
+        face_norm_roted = face_norm @ rotation
+        face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+        return face_vertex, face_texture, face_color, landmark
+
+    def compute_for_render_woRotation(self, coeffs):
+        """
+        Return:
+            face_vertex     -- torch.tensor, size (B, N, 3), in camera coordinate
+            face_color      -- torch.tensor, size (B, N, 3), in RGB order
+            landmark        -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
+        Parameters:
+            coeffs          -- torch.tensor, size (B, 257)
+        """
+        coef_dict = self.split_coeff(coeffs)
+        face_shape = self.compute_shape(coef_dict['id'], coef_dict['exp'])
+        #rotation = self.compute_rotation(coef_dict['angle'])
+
+
+        #face_shape_transformed = self.transform(face_shape, rotation, coef_dict['trans'])
+        face_vertex = self.to_camera(face_shape)
+        
+        face_proj = self.to_image(face_vertex)
+        landmark = self.get_landmarks(face_proj)
+
+        face_texture = self.compute_texture(coef_dict['tex'])
+        face_norm = self.compute_norm(face_shape)
+        face_norm_roted = face_norm                                    # @ rotation
+        face_color = self.compute_color(face_texture, face_norm_roted, coef_dict['gamma'])
+
+        return face_vertex, face_texture, face_color, landmark
+
+
+if __name__ == '__main__':
+    transferBFM09()
\ No newline at end of file
diff --git a/src/face3d/models/facerecon_model.py b/src/face3d/models/facerecon_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f40720a35a19b883191cad60d7be4a0ae12a2493
--- /dev/null
+++ b/src/face3d/models/facerecon_model.py
@@ -0,0 +1,220 @@
+"""This script defines the face reconstruction model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import torch
+from src.face3d.models.base_model import BaseModel
+from src.face3d.models import networks
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.losses import perceptual_loss, photo_loss, reg_loss, reflectance_loss, landmark_loss
+from src.face3d.util import util 
+from src.face3d.util.nvdiffrast import MeshRenderer
+# from src.face3d.util.preprocess import estimate_norm_torch
+
+import trimesh
+from scipy.io import savemat
+
+class FaceReconModel(BaseModel):
+
+    @staticmethod
+    def modify_commandline_options(parser, is_train=False):
+        """  Configures options specific for CUT model
+        """
+        # net structure and parameters
+        parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='network structure')
+        parser.add_argument('--init_path', type=str, default='./checkpoints/init_model/resnet50-0676ba61.pth')
+        parser.add_argument('--use_last_fc', type=util.str2bool, nargs='?', const=True, default=False, help='zero initialize the last fc')
+        parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
+        parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+        # renderer parameters
+        parser.add_argument('--focal', type=float, default=1015.)
+        parser.add_argument('--center', type=float, default=112.)
+        parser.add_argument('--camera_d', type=float, default=10.)
+        parser.add_argument('--z_near', type=float, default=5.)
+        parser.add_argument('--z_far', type=float, default=15.)
+
+        if is_train:
+            # training parameters
+            parser.add_argument('--net_recog', type=str, default='r50', choices=['r18', 'r43', 'r50'], help='face recog network structure')
+            parser.add_argument('--net_recog_path', type=str, default='checkpoints/recog_model/ms1mv3_arcface_r50_fp16/backbone.pth')
+            parser.add_argument('--use_crop_face', type=util.str2bool, nargs='?', const=True, default=False, help='use crop mask for photo loss')
+            parser.add_argument('--use_predef_M', type=util.str2bool, nargs='?', const=True, default=False, help='use predefined M for predicted face')
+
+            
+            # augmentation parameters
+            parser.add_argument('--shift_pixs', type=float, default=10., help='shift pixels')
+            parser.add_argument('--scale_delta', type=float, default=0.1, help='delta scale factor')
+            parser.add_argument('--rot_angle', type=float, default=10., help='rot angles, degree')
+
+            # loss weights
+            parser.add_argument('--w_feat', type=float, default=0.2, help='weight for feat loss')
+            parser.add_argument('--w_color', type=float, default=1.92, help='weight for loss loss')
+            parser.add_argument('--w_reg', type=float, default=3.0e-4, help='weight for reg loss')
+            parser.add_argument('--w_id', type=float, default=1.0, help='weight for id_reg loss')
+            parser.add_argument('--w_exp', type=float, default=0.8, help='weight for exp_reg loss')
+            parser.add_argument('--w_tex', type=float, default=1.7e-2, help='weight for tex_reg loss')
+            parser.add_argument('--w_gamma', type=float, default=10.0, help='weight for gamma loss')
+            parser.add_argument('--w_lm', type=float, default=1.6e-3, help='weight for lm loss')
+            parser.add_argument('--w_reflc', type=float, default=5.0, help='weight for reflc loss')
+
+        opt, _ = parser.parse_known_args()
+        parser.set_defaults(
+                focal=1015., center=112., camera_d=10., use_last_fc=False, z_near=5., z_far=15.
+            )
+        if is_train:
+            parser.set_defaults(
+                use_crop_face=True, use_predef_M=False
+            )
+        return parser
+
+    def __init__(self, opt):
+        """Initialize this model class.
+
+        Parameters:
+            opt -- training/test options
+
+        A few things can be done here.
+        - (required) call the initialization function of BaseModel
+        - define loss function, visualization images, model names, and optimizers
+        """
+        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
+        
+        self.visual_names = ['output_vis']
+        self.model_names = ['net_recon']
+        self.parallel_names = self.model_names + ['renderer']
+
+        self.facemodel = ParametricFaceModel(
+            bfm_folder=opt.bfm_folder, camera_distance=opt.camera_d, focal=opt.focal, center=opt.center,
+            is_train=self.isTrain, default_name=opt.bfm_model
+        )
+        
+        fov = 2 * np.arctan(opt.center / opt.focal) * 180 / np.pi
+        self.renderer = MeshRenderer(
+            rasterize_fov=fov, znear=opt.z_near, zfar=opt.z_far, rasterize_size=int(2 * opt.center)
+        )
+
+        if self.isTrain:
+            self.loss_names = ['all', 'feat', 'color', 'lm', 'reg', 'gamma', 'reflc']
+
+            self.net_recog = networks.define_net_recog(
+                net_recog=opt.net_recog, pretrained_path=opt.net_recog_path
+                )
+            # loss func name: (compute_%s_loss) % loss_name
+            self.compute_feat_loss = perceptual_loss
+            self.comupte_color_loss = photo_loss
+            self.compute_lm_loss = landmark_loss
+            self.compute_reg_loss = reg_loss
+            self.compute_reflc_loss = reflectance_loss
+
+            self.optimizer = torch.optim.Adam(self.net_recon.parameters(), lr=opt.lr)
+            self.optimizers = [self.optimizer]
+            self.parallel_names += ['net_recog']
+        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
+
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input: a dictionary that contains the data itself and its metadata information.
+        """
+        self.input_img = input['imgs'].to(self.device) 
+        self.atten_mask = input['msks'].to(self.device) if 'msks' in input else None
+        self.gt_lm = input['lms'].to(self.device)  if 'lms' in input else None
+        self.trans_m = input['M'].to(self.device) if 'M' in input else None
+        self.image_paths = input['im_paths'] if 'im_paths' in input else None
+
+    def forward(self, output_coeff, device):
+        self.facemodel.to(device)
+        self.pred_vertex, self.pred_tex, self.pred_color, self.pred_lm = \
+            self.facemodel.compute_for_render(output_coeff)
+        self.pred_mask, _, self.pred_face = self.renderer(
+            self.pred_vertex, self.facemodel.face_buf, feat=self.pred_color)
+        
+        self.pred_coeffs_dict = self.facemodel.split_coeff(output_coeff)
+
+
+    def compute_losses(self):
+        """Calculate losses, gradients, and update network weights; called in every training iteration"""
+
+        assert self.net_recog.training == False
+        trans_m = self.trans_m
+        if not self.opt.use_predef_M:
+            trans_m = estimate_norm_torch(self.pred_lm, self.input_img.shape[-2])
+
+        pred_feat = self.net_recog(self.pred_face, trans_m)
+        gt_feat = self.net_recog(self.input_img, self.trans_m)
+        self.loss_feat = self.opt.w_feat * self.compute_feat_loss(pred_feat, gt_feat)
+
+        face_mask = self.pred_mask
+        if self.opt.use_crop_face:
+            face_mask, _, _ = self.renderer(self.pred_vertex, self.facemodel.front_face_buf)
+        
+        face_mask = face_mask.detach()
+        self.loss_color = self.opt.w_color * self.comupte_color_loss(
+            self.pred_face, self.input_img, self.atten_mask * face_mask)
+        
+        loss_reg, loss_gamma = self.compute_reg_loss(self.pred_coeffs_dict, self.opt)
+        self.loss_reg = self.opt.w_reg * loss_reg
+        self.loss_gamma = self.opt.w_gamma * loss_gamma
+
+        self.loss_lm = self.opt.w_lm * self.compute_lm_loss(self.pred_lm, self.gt_lm)
+
+        self.loss_reflc = self.opt.w_reflc * self.compute_reflc_loss(self.pred_tex, self.facemodel.skin_mask)
+
+        self.loss_all = self.loss_feat + self.loss_color + self.loss_reg + self.loss_gamma \
+                        + self.loss_lm + self.loss_reflc
+            
+
+    def optimize_parameters(self, isTrain=True):
+        self.forward()               
+        self.compute_losses()
+        """Update network weights; it will be called in every training iteration."""
+        if isTrain:
+            self.optimizer.zero_grad()  
+            self.loss_all.backward()         
+            self.optimizer.step()        
+
+    def compute_visuals(self):
+        with torch.no_grad():
+            input_img_numpy = 255. * self.input_img.detach().cpu().permute(0, 2, 3, 1).numpy()
+            output_vis = self.pred_face * self.pred_mask + (1 - self.pred_mask) * self.input_img
+            output_vis_numpy_raw = 255. * output_vis.detach().cpu().permute(0, 2, 3, 1).numpy()
+            
+            if self.gt_lm is not None:
+                gt_lm_numpy = self.gt_lm.cpu().numpy()
+                pred_lm_numpy = self.pred_lm.detach().cpu().numpy()
+                output_vis_numpy = util.draw_landmarks(output_vis_numpy_raw, gt_lm_numpy, 'b')
+                output_vis_numpy = util.draw_landmarks(output_vis_numpy, pred_lm_numpy, 'r')
+            
+                output_vis_numpy = np.concatenate((input_img_numpy, 
+                                    output_vis_numpy_raw, output_vis_numpy), axis=-2)
+            else:
+                output_vis_numpy = np.concatenate((input_img_numpy, 
+                                    output_vis_numpy_raw), axis=-2)
+
+            self.output_vis = torch.tensor(
+                    output_vis_numpy / 255., dtype=torch.float32
+                ).permute(0, 3, 1, 2).to(self.device)
+
+    def save_mesh(self, name):
+
+        recon_shape = self.pred_vertex  # get reconstructed shape
+        recon_shape[..., -1] = 10 - recon_shape[..., -1] # from camera space to world space
+        recon_shape = recon_shape.cpu().numpy()[0]
+        recon_color = self.pred_color
+        recon_color = recon_color.cpu().numpy()[0]
+        tri = self.facemodel.face_buf.cpu().numpy()
+        mesh = trimesh.Trimesh(vertices=recon_shape, faces=tri, vertex_colors=np.clip(255. * recon_color, 0, 255).astype(np.uint8))
+        mesh.export(name)
+
+    def save_coeff(self,name):
+
+        pred_coeffs = {key:self.pred_coeffs_dict[key].cpu().numpy() for key in self.pred_coeffs_dict}
+        pred_lm = self.pred_lm.cpu().numpy()
+        pred_lm = np.stack([pred_lm[:,:,0],self.input_img.shape[2]-1-pred_lm[:,:,1]],axis=2) # transfer to image coordinate
+        pred_coeffs['lm68'] = pred_lm
+        savemat(name,pred_coeffs)
+
+
+
diff --git a/src/face3d/models/losses.py b/src/face3d/models/losses.py
new file mode 100644
index 0000000000000000000000000000000000000000..01d9da84f28d54e772bebd2385ae5a7fedd10f7d
--- /dev/null
+++ b/src/face3d/models/losses.py
@@ -0,0 +1,113 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from kornia.geometry import warp_affine
+import torch.nn.functional as F
+
+def resize_n_crop(image, M, dsize=112):
+    # image: (b, c, h, w)
+    # M   :  (b, 2, 3)
+    return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+### perceptual level loss
+class PerceptualLoss(nn.Module):
+    def __init__(self, recog_net, input_size=112):
+        super(PerceptualLoss, self).__init__()
+        self.recog_net = recog_net
+        self.preprocess = lambda x: 2 * x - 1
+        self.input_size=input_size
+    def forward(imageA, imageB, M):
+        """
+        1 - cosine distance
+        Parameters:
+            imageA       --torch.tensor (B, 3, H, W), range (0, 1) , RGB order
+            imageB       --same as imageA
+        """
+
+        imageA = self.preprocess(resize_n_crop(imageA, M, self.input_size))
+        imageB = self.preprocess(resize_n_crop(imageB, M, self.input_size))
+
+        # freeze bn
+        self.recog_net.eval()
+        
+        id_featureA = F.normalize(self.recog_net(imageA), dim=-1, p=2)
+        id_featureB = F.normalize(self.recog_net(imageB), dim=-1, p=2)  
+        cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+        # assert torch.sum((cosine_d > 1).float()) == 0
+        return torch.sum(1 - cosine_d) / cosine_d.shape[0]        
+
+def perceptual_loss(id_featureA, id_featureB):
+    cosine_d = torch.sum(id_featureA * id_featureB, dim=-1)
+        # assert torch.sum((cosine_d > 1).float()) == 0
+    return torch.sum(1 - cosine_d) / cosine_d.shape[0]  
+
+### image level loss
+def photo_loss(imageA, imageB, mask, eps=1e-6):
+    """
+    l2 norm (with sqrt, to ensure backward stabililty, use eps, otherwise Nan may occur)
+    Parameters:
+        imageA       --torch.tensor (B, 3, H, W), range (0, 1), RGB order 
+        imageB       --same as imageA
+    """
+    loss = torch.sqrt(eps + torch.sum((imageA - imageB) ** 2, dim=1, keepdims=True)) * mask
+    loss = torch.sum(loss) / torch.max(torch.sum(mask), torch.tensor(1.0).to(mask.device))
+    return loss
+
+def landmark_loss(predict_lm, gt_lm, weight=None):
+    """
+    weighted mse loss
+    Parameters:
+        predict_lm    --torch.tensor (B, 68, 2)
+        gt_lm         --torch.tensor (B, 68, 2)
+        weight        --numpy.array (1, 68)
+    """
+    if not weight:
+        weight = np.ones([68])
+        weight[28:31] = 20
+        weight[-8:] = 20
+        weight = np.expand_dims(weight, 0)
+        weight = torch.tensor(weight).to(predict_lm.device)
+    loss = torch.sum((predict_lm - gt_lm)**2, dim=-1) * weight
+    loss = torch.sum(loss) / (predict_lm.shape[0] * predict_lm.shape[1])
+    return loss
+
+
+### regulization
+def reg_loss(coeffs_dict, opt=None):
+    """
+    l2 norm without the sqrt, from yu's implementation (mse)
+    tf.nn.l2_loss https://www.tensorflow.org/api_docs/python/tf/nn/l2_loss
+    Parameters:
+        coeffs_dict     -- a  dict of torch.tensors , keys: id, exp, tex, angle, gamma, trans
+
+    """
+    # coefficient regularization to ensure plausible 3d faces
+    if opt:
+        w_id, w_exp, w_tex = opt.w_id, opt.w_exp, opt.w_tex
+    else:
+        w_id, w_exp, w_tex = 1, 1, 1, 1
+    creg_loss = w_id * torch.sum(coeffs_dict['id'] ** 2) +  \
+           w_exp * torch.sum(coeffs_dict['exp'] ** 2) + \
+           w_tex * torch.sum(coeffs_dict['tex'] ** 2)
+    creg_loss = creg_loss / coeffs_dict['id'].shape[0]
+
+    # gamma regularization to ensure a nearly-monochromatic light
+    gamma = coeffs_dict['gamma'].reshape([-1, 3, 9])
+    gamma_mean = torch.mean(gamma, dim=1, keepdims=True)
+    gamma_loss = torch.mean((gamma - gamma_mean) ** 2)
+
+    return creg_loss, gamma_loss
+
+def reflectance_loss(texture, mask):
+    """
+    minimize texture variance (mse), albedo regularization to ensure an uniform skin albedo
+    Parameters:
+        texture       --torch.tensor, (B, N, 3)
+        mask          --torch.tensor, (N), 1 or 0
+
+    """
+    mask = mask.reshape([1, mask.shape[0], 1])
+    texture_mean = torch.sum(mask * texture, dim=1, keepdims=True) / torch.sum(mask)
+    loss = torch.sum(((texture - texture_mean) * mask)**2) / (texture.shape[0] * torch.sum(mask))
+    return loss
+
diff --git a/src/face3d/models/networks.py b/src/face3d/models/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e69eba1ade2e6431e7e7fd526ea68b8f63e7152
--- /dev/null
+++ b/src/face3d/models/networks.py
@@ -0,0 +1,521 @@
+"""This script defines deep neural networks for Deep3DFaceRecon_pytorch
+"""
+
+import os
+import numpy as np
+import torch.nn.functional as F
+from torch.nn import init
+import functools
+from torch.optim import lr_scheduler
+import torch
+from torch import Tensor
+import torch.nn as nn
+try:
+    from torch.hub import load_state_dict_from_url
+except ImportError:
+    from torch.utils.model_zoo import load_url as load_state_dict_from_url
+from typing import Type, Any, Callable, Union, List, Optional
+from .arcface_torch.backbones import get_model
+from kornia.geometry import warp_affine
+
+def resize_n_crop(image, M, dsize=112):
+    # image: (b, c, h, w)
+    # M   :  (b, 2, 3)
+    return warp_affine(image, M, dsize=(dsize, dsize), align_corners=True)
+
+def filter_state_dict(state_dict, remove_name='fc'):
+    new_state_dict = {}
+    for key in state_dict:
+        if remove_name in key:
+            continue
+        new_state_dict[key] = state_dict[key]
+    return new_state_dict
+
+def get_scheduler(optimizer, opt):
+    """Return a learning rate scheduler
+
+    Parameters:
+        optimizer          -- the optimizer of the network
+        opt (option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions. 
+                              opt.lr_policy is the name of learning rate policy: linear | step | plateau | cosine
+
+    For other schedulers (step, plateau, and cosine), we use the default PyTorch schedulers.
+    See https://pytorch.org/docs/stable/optim.html for more details.
+    """
+    if opt.lr_policy == 'linear':
+        def lambda_rule(epoch):
+            lr_l = 1.0 - max(0, epoch + opt.epoch_count - opt.n_epochs) / float(opt.n_epochs + 1)
+            return lr_l
+        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
+    elif opt.lr_policy == 'step':
+        scheduler = lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_epochs, gamma=0.2)
+    elif opt.lr_policy == 'plateau':
+        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
+    elif opt.lr_policy == 'cosine':
+        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
+    else:
+        return NotImplementedError('learning rate policy [%s] is not implemented', opt.lr_policy)
+    return scheduler
+
+
+def define_net_recon(net_recon, use_last_fc=False, init_path=None):
+    return ReconNetWrapper(net_recon, use_last_fc=use_last_fc, init_path=init_path)
+
+def define_net_recog(net_recog, pretrained_path=None):
+    net = RecogNetWrapper(net_recog=net_recog, pretrained_path=pretrained_path)
+    net.eval()
+    return net
+
+class ReconNetWrapper(nn.Module):
+    fc_dim=257
+    def __init__(self, net_recon, use_last_fc=False, init_path=None):
+        super(ReconNetWrapper, self).__init__()
+        self.use_last_fc = use_last_fc
+        if net_recon not in func_dict:
+            return  NotImplementedError('network [%s] is not implemented', net_recon)
+        func, last_dim = func_dict[net_recon]
+        backbone = func(use_last_fc=use_last_fc, num_classes=self.fc_dim)
+        if init_path and os.path.isfile(init_path):
+            state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
+            backbone.load_state_dict(state_dict)
+            print("loading init net_recon %s from %s" %(net_recon, init_path))
+        self.backbone = backbone
+        if not use_last_fc:
+            self.final_layers = nn.ModuleList([
+                conv1x1(last_dim, 80, bias=True), # id layer
+                conv1x1(last_dim, 64, bias=True), # exp layer
+                conv1x1(last_dim, 80, bias=True), # tex layer
+                conv1x1(last_dim, 3, bias=True),  # angle layer
+                conv1x1(last_dim, 27, bias=True), # gamma layer
+                conv1x1(last_dim, 2, bias=True),  # tx, ty
+                conv1x1(last_dim, 1, bias=True)   # tz
+            ])
+            for m in self.final_layers:
+                nn.init.constant_(m.weight, 0.)
+                nn.init.constant_(m.bias, 0.)
+
+    def forward(self, x):
+        x = self.backbone(x)
+        if not self.use_last_fc:
+            output = []
+            for layer in self.final_layers:
+                output.append(layer(x))
+            x = torch.flatten(torch.cat(output, dim=1), 1)
+        return x
+
+
+class RecogNetWrapper(nn.Module):
+    def __init__(self, net_recog, pretrained_path=None, input_size=112):
+        super(RecogNetWrapper, self).__init__()
+        net = get_model(name=net_recog, fp16=False)
+        if pretrained_path:
+            state_dict = torch.load(pretrained_path, map_location='cpu')
+            net.load_state_dict(state_dict)
+            print("loading pretrained net_recog %s from %s" %(net_recog, pretrained_path))
+        for param in net.parameters():
+            param.requires_grad = False
+        self.net = net
+        self.preprocess = lambda x: 2 * x - 1
+        self.input_size=input_size
+        
+    def forward(self, image, M):
+        image = self.preprocess(resize_n_crop(image, M, self.input_size))
+        id_feature = F.normalize(self.net(image), dim=-1, p=2)
+        return id_feature
+
+
+# adapted from https://github.com/pytorch/vision/edit/master/torchvision/models/resnet.py
+__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
+           'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
+           'wide_resnet50_2', 'wide_resnet101_2']
+
+
+model_urls = {
+    'resnet18': 'https://download.pytorch.org/models/resnet18-f37072fd.pth',
+    'resnet34': 'https://download.pytorch.org/models/resnet34-b627a593.pth',
+    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
+    'resnet101': 'https://download.pytorch.org/models/resnet101-63fe2227.pth',
+    'resnet152': 'https://download.pytorch.org/models/resnet152-394f9c45.pth',
+    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
+    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
+    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
+    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
+}
+
+
+def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
+    """3x3 convolution with padding"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
+                     padding=dilation, groups=groups, bias=False, dilation=dilation)
+
+
+def conv1x1(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
+    """1x1 convolution"""
+    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
+
+
+class BasicBlock(nn.Module):
+    expansion: int = 1
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(BasicBlock, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        if groups != 1 or base_width != 64:
+            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
+        if dilation > 1:
+            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
+        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv3x3(inplanes, planes, stride)
+        self.bn1 = norm_layer(planes)
+        self.relu = nn.ReLU(inplace=True)
+        self.conv2 = conv3x3(planes, planes)
+        self.bn2 = norm_layer(planes)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class Bottleneck(nn.Module):
+    # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
+    # while original implementation places the stride at the first 1x1 convolution(self.conv1)
+    # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
+    # This variant is also known as ResNet V1.5 and improves accuracy according to
+    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
+
+    expansion: int = 4
+
+    def __init__(
+        self,
+        inplanes: int,
+        planes: int,
+        stride: int = 1,
+        downsample: Optional[nn.Module] = None,
+        groups: int = 1,
+        base_width: int = 64,
+        dilation: int = 1,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(Bottleneck, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        width = int(planes * (base_width / 64.)) * groups
+        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
+        self.conv1 = conv1x1(inplanes, width)
+        self.bn1 = norm_layer(width)
+        self.conv2 = conv3x3(width, width, stride, groups, dilation)
+        self.bn2 = norm_layer(width)
+        self.conv3 = conv1x1(width, planes * self.expansion)
+        self.bn3 = norm_layer(planes * self.expansion)
+        self.relu = nn.ReLU(inplace=True)
+        self.downsample = downsample
+        self.stride = stride
+
+    def forward(self, x: Tensor) -> Tensor:
+        identity = x
+
+        out = self.conv1(x)
+        out = self.bn1(out)
+        out = self.relu(out)
+
+        out = self.conv2(out)
+        out = self.bn2(out)
+        out = self.relu(out)
+
+        out = self.conv3(out)
+        out = self.bn3(out)
+
+        if self.downsample is not None:
+            identity = self.downsample(x)
+
+        out += identity
+        out = self.relu(out)
+
+        return out
+
+
+class ResNet(nn.Module):
+
+    def __init__(
+        self,
+        block: Type[Union[BasicBlock, Bottleneck]],
+        layers: List[int],
+        num_classes: int = 1000,
+        zero_init_residual: bool = False,
+        use_last_fc: bool = False,
+        groups: int = 1,
+        width_per_group: int = 64,
+        replace_stride_with_dilation: Optional[List[bool]] = None,
+        norm_layer: Optional[Callable[..., nn.Module]] = None
+    ) -> None:
+        super(ResNet, self).__init__()
+        if norm_layer is None:
+            norm_layer = nn.BatchNorm2d
+        self._norm_layer = norm_layer
+
+        self.inplanes = 64
+        self.dilation = 1
+        if replace_stride_with_dilation is None:
+            # each element in the tuple indicates if we should replace
+            # the 2x2 stride with a dilated convolution instead
+            replace_stride_with_dilation = [False, False, False]
+        if len(replace_stride_with_dilation) != 3:
+            raise ValueError("replace_stride_with_dilation should be None "
+                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
+        self.use_last_fc = use_last_fc
+        self.groups = groups
+        self.base_width = width_per_group
+        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
+                               bias=False)
+        self.bn1 = norm_layer(self.inplanes)
+        self.relu = nn.ReLU(inplace=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+        self.layer1 = self._make_layer(block, 64, layers[0])
+        self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
+                                       dilate=replace_stride_with_dilation[0])
+        self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
+                                       dilate=replace_stride_with_dilation[1])
+        self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
+                                       dilate=replace_stride_with_dilation[2])
+        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
+        
+        if self.use_last_fc:
+            self.fc = nn.Linear(512 * block.expansion, num_classes)
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.constant_(m.weight, 1)
+                nn.init.constant_(m.bias, 0)
+
+
+
+        # Zero-initialize the last BN in each residual branch,
+        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
+        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
+        if zero_init_residual:
+            for m in self.modules():
+                if isinstance(m, Bottleneck):
+                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
+                elif isinstance(m, BasicBlock):
+                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]
+
+    def _make_layer(self, block: Type[Union[BasicBlock, Bottleneck]], planes: int, blocks: int,
+                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
+        norm_layer = self._norm_layer
+        downsample = None
+        previous_dilation = self.dilation
+        if dilate:
+            self.dilation *= stride
+            stride = 1
+        if stride != 1 or self.inplanes != planes * block.expansion:
+            downsample = nn.Sequential(
+                conv1x1(self.inplanes, planes * block.expansion, stride),
+                norm_layer(planes * block.expansion),
+            )
+
+        layers = []
+        layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
+                            self.base_width, previous_dilation, norm_layer))
+        self.inplanes = planes * block.expansion
+        for _ in range(1, blocks):
+            layers.append(block(self.inplanes, planes, groups=self.groups,
+                                base_width=self.base_width, dilation=self.dilation,
+                                norm_layer=norm_layer))
+
+        return nn.Sequential(*layers)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        # See note [TorchScript super()]
+        x = self.conv1(x)
+        x = self.bn1(x)
+        x = self.relu(x)
+        x = self.maxpool(x)
+
+        x = self.layer1(x)
+        x = self.layer2(x)
+        x = self.layer3(x)
+        x = self.layer4(x)
+
+        x = self.avgpool(x)
+        if self.use_last_fc:
+            x = torch.flatten(x, 1)
+            x = self.fc(x)
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+def _resnet(
+    arch: str,
+    block: Type[Union[BasicBlock, Bottleneck]],
+    layers: List[int],
+    pretrained: bool,
+    progress: bool,
+    **kwargs: Any
+) -> ResNet:
+    model = ResNet(block, layers, **kwargs)
+    if pretrained:
+        state_dict = load_state_dict_from_url(model_urls[arch],
+                                              progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def resnet18(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-18 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
+                   **kwargs)
+
+
+def resnet34(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-34 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-50 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet101(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-101 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnet152(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNet-152 model from
+    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
+                   **kwargs)
+
+
+def resnext50_32x4d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNeXt-50 32x4d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 4
+    return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def resnext101_32x8d(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""ResNeXt-101 32x8d model from
+    `"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['groups'] = 32
+    kwargs['width_per_group'] = 8
+    return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet50_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""Wide ResNet-50-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
+
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
+                   pretrained, progress, **kwargs)
+
+
+def wide_resnet101_2(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> ResNet:
+    r"""Wide ResNet-101-2 model from
+    `"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_.
+
+    The model is the same as ResNet except for the bottleneck number of channels
+    which is twice larger in every block. The number of channels in outer 1x1
+    convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
+    channels, and in Wide ResNet-50-2 has 2048-1024-2048.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+    """
+    kwargs['width_per_group'] = 64 * 2
+    return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
+                   pretrained, progress, **kwargs)
+
+
+func_dict = {
+    'resnet18': (resnet18, 512),
+    'resnet50': (resnet50, 2048)
+}
diff --git a/src/face3d/models/template_model.py b/src/face3d/models/template_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..75860272a06312bfa4de382729dce5136a480a7f
--- /dev/null
+++ b/src/face3d/models/template_model.py
@@ -0,0 +1,100 @@
+"""Model class template
+
+This module provides a template for users to implement custom models.
+You can specify '--model template' to use this model.
+The class name should be consistent with both the filename and its model option.
+The filename should be <model>_dataset.py
+The class name should be <Model>Dataset.py
+It implements a simple image-to-image translation baseline based on regression loss.
+Given input-output pairs (data_A, data_B), it learns a network netG that can minimize the following L1 loss:
+    min_<netG> ||netG(data_A) - data_B||_1
+You need to implement the following functions:
+    <modify_commandline_options>: Add model-specific options and rewrite default values for existing options.
+    <__init__>: Initialize this model class.
+    <set_input>: Unpack input data and perform data pre-processing.
+    <forward>: Run forward pass. This will be called by both <optimize_parameters> and <test>.
+    <optimize_parameters>: Update network weights; it will be called in every training iteration.
+"""
+import numpy as np
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class TemplateModel(BaseModel):
+    @staticmethod
+    def modify_commandline_options(parser, is_train=True):
+        """Add new model-specific options and rewrite default values for existing options.
+
+        Parameters:
+            parser -- the option parser
+            is_train -- if it is training phase or test phase. You can use this flag to add training-specific or test-specific options.
+
+        Returns:
+            the modified parser.
+        """
+        parser.set_defaults(dataset_mode='aligned')  # You can rewrite default values for this model. For example, this model usually uses aligned dataset as its dataset.
+        if is_train:
+            parser.add_argument('--lambda_regression', type=float, default=1.0, help='weight for the regression loss')  # You can define new arguments for this model.
+
+        return parser
+
+    def __init__(self, opt):
+        """Initialize this model class.
+
+        Parameters:
+            opt -- training/test options
+
+        A few things can be done here.
+        - (required) call the initialization function of BaseModel
+        - define loss function, visualization images, model names, and optimizers
+        """
+        BaseModel.__init__(self, opt)  # call the initialization method of BaseModel
+        # specify the training losses you want to print out. The program will call base_model.get_current_losses to plot the losses to the console and save them to the disk.
+        self.loss_names = ['loss_G']
+        # specify the images you want to save and display. The program will call base_model.get_current_visuals to save and display these images.
+        self.visual_names = ['data_A', 'data_B', 'output']
+        # specify the models you want to save to the disk. The program will call base_model.save_networks and base_model.load_networks to save and load networks.
+        # you can use opt.isTrain to specify different behaviors for training and test. For example, some networks will not be used during test, and you don't need to load them.
+        self.model_names = ['G']
+        # define networks; you can use opt.isTrain to specify different behaviors for training and test.
+        self.netG = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, gpu_ids=self.gpu_ids)
+        if self.isTrain:  # only defined during training time
+            # define your loss functions. You can use losses provided by torch.nn such as torch.nn.L1Loss.
+            # We also provide a GANLoss class "networks.GANLoss". self.criterionGAN = networks.GANLoss().to(self.device)
+            self.criterionLoss = torch.nn.L1Loss()
+            # define and initialize optimizers. You can define one optimizer for each network.
+            # If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+            self.optimizer = torch.optim.Adam(self.netG.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
+            self.optimizers = [self.optimizer]
+
+        # Our program will automatically call <model.setup> to define schedulers, load networks, and print networks
+
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input: a dictionary that contains the data itself and its metadata information.
+        """
+        AtoB = self.opt.direction == 'AtoB'  # use <direction> to swap data_A and data_B
+        self.data_A = input['A' if AtoB else 'B'].to(self.device)  # get image data A
+        self.data_B = input['B' if AtoB else 'A'].to(self.device)  # get image data B
+        self.image_paths = input['A_paths' if AtoB else 'B_paths']  # get image paths
+
+    def forward(self):
+        """Run forward pass. This will be called by both functions <optimize_parameters> and <test>."""
+        self.output = self.netG(self.data_A)  # generate output image given the input data_A
+
+    def backward(self):
+        """Calculate losses, gradients, and update network weights; called in every training iteration"""
+        # caculate the intermediate results if necessary; here self.output has been computed during function <forward>
+        # calculate loss given the input and intermediate results
+        self.loss_G = self.criterionLoss(self.output, self.data_B) * self.opt.lambda_regression
+        self.loss_G.backward()       # calculate gradients of network G w.r.t. loss_G
+
+    def optimize_parameters(self):
+        """Update network weights; it will be called in every training iteration."""
+        self.forward()               # first call forward to calculate intermediate results
+        self.optimizer.zero_grad()   # clear network G's existing gradients
+        self.backward()              # calculate gradients for network G
+        self.optimizer.step()        # update gradients for network G
diff --git a/src/face3d/options/__init__.py b/src/face3d/options/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..06559aa558cf178b946c4523b28b098d1dfad606
--- /dev/null
+++ b/src/face3d/options/__init__.py
@@ -0,0 +1 @@
+"""This package options includes option modules: training options, test options, and basic options (used in both training and test)."""
diff --git a/src/face3d/options/base_options.py b/src/face3d/options/base_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..23782ac32d32510c7fbab6c06a58ce98722896ce
--- /dev/null
+++ b/src/face3d/options/base_options.py
@@ -0,0 +1,169 @@
+"""This script contains base options for Deep3DFaceRecon_pytorch
+"""
+
+import argparse
+import os
+from util import util
+import numpy as np
+import torch
+import face3d.models as models
+import face3d.data as data
+
+
+class BaseOptions():
+    """This class defines options used during both training and test time.
+
+    It also implements several helper functions such as parsing, printing, and saving the options.
+    It also gathers additional options defined in <modify_commandline_options> functions in both dataset class and model class.
+    """
+
+    def __init__(self, cmd_line=None):
+        """Reset the class; indicates the class hasn't been initailized"""
+        self.initialized = False
+        self.cmd_line = None
+        if cmd_line is not None:
+            self.cmd_line = cmd_line.split()
+
+    def initialize(self, parser):
+        """Define the common options that are used in both training and test."""
+        # basic parameters
+        parser.add_argument('--name', type=str, default='face_recon', help='name of the experiment. It decides where to store samples and models')
+        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0  0,1,2, 0,2. use -1 for CPU')
+        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
+        parser.add_argument('--vis_batch_nums', type=float, default=1, help='batch nums of images for visulization')
+        parser.add_argument('--eval_batch_nums', type=float, default=float('inf'), help='batch nums of images for evaluation')
+        parser.add_argument('--use_ddp', type=util.str2bool, nargs='?', const=True, default=True, help='whether use distributed data parallel')
+        parser.add_argument('--ddp_port', type=str, default='12355', help='ddp port')
+        parser.add_argument('--display_per_batch', type=util.str2bool, nargs='?', const=True, default=True, help='whether use batch to show losses')
+        parser.add_argument('--add_image', type=util.str2bool, nargs='?', const=True, default=True, help='whether add image to tensorboard')
+        parser.add_argument('--world_size', type=int, default=1, help='batch nums of images for evaluation')
+
+        # model parameters
+        parser.add_argument('--model', type=str, default='facerecon', help='chooses which model to use.')
+
+        # additional parameters
+        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load? set to latest to use latest cached model')
+        parser.add_argument('--verbose', action='store_true', help='if specified, print more debugging information')
+        parser.add_argument('--suffix', default='', type=str, help='customized suffix: opt.name = opt.name + suffix: e.g., {model}_{netG}_size{load_size}')
+
+        self.initialized = True
+        return parser
+
+    def gather_options(self):
+        """Initialize our parser with basic options(only once).
+        Add additional model-specific and dataset-specific options.
+        These options are defined in the <modify_commandline_options> function
+        in model and dataset classes.
+        """
+        if not self.initialized:  # check if it has been initialized
+            parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+            parser = self.initialize(parser)
+
+        # get the basic options
+        if self.cmd_line is None:
+            opt, _ = parser.parse_known_args()
+        else:
+            opt, _ = parser.parse_known_args(self.cmd_line)
+
+        # set cuda visible devices
+        os.environ['CUDA_VISIBLE_DEVICES'] = opt.gpu_ids
+
+        # modify model-related parser options
+        model_name = opt.model
+        model_option_setter = models.get_option_setter(model_name)
+        parser = model_option_setter(parser, self.isTrain)
+        if self.cmd_line is None:
+            opt, _ = parser.parse_known_args()  # parse again with new defaults
+        else:
+            opt, _ = parser.parse_known_args(self.cmd_line)  # parse again with new defaults
+
+        # modify dataset-related parser options
+        if opt.dataset_mode:
+            dataset_name = opt.dataset_mode
+            dataset_option_setter = data.get_option_setter(dataset_name)
+            parser = dataset_option_setter(parser, self.isTrain)
+
+        # save and return the parser
+        self.parser = parser
+        if self.cmd_line is None:
+            return parser.parse_args()
+        else:
+            return parser.parse_args(self.cmd_line)
+
+    def print_options(self, opt):
+        """Print and save options
+
+        It will print both current options and default values(if different).
+        It will save options into a text file / [checkpoints_dir] / opt.txt
+        """
+        message = ''
+        message += '----------------- Options ---------------\n'
+        for k, v in sorted(vars(opt).items()):
+            comment = ''
+            default = self.parser.get_default(k)
+            if v != default:
+                comment = '\t[default: %s]' % str(default)
+            message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
+        message += '----------------- End -------------------'
+        print(message)
+
+        # save to the disk
+        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
+        util.mkdirs(expr_dir)
+        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
+        try:
+            with open(file_name, 'wt') as opt_file:
+                opt_file.write(message)
+                opt_file.write('\n')
+        except PermissionError as error:
+            print("permission error {}".format(error))
+            pass
+
+    def parse(self):
+        """Parse our options, create checkpoints directory suffix, and set up gpu device."""
+        opt = self.gather_options()
+        opt.isTrain = self.isTrain   # train or test
+
+        # process opt.suffix
+        if opt.suffix:
+            suffix = ('_' + opt.suffix.format(**vars(opt))) if opt.suffix != '' else ''
+            opt.name = opt.name + suffix
+
+
+        # set gpu ids
+        str_ids = opt.gpu_ids.split(',')
+        gpu_ids = []
+        for str_id in str_ids:
+            id = int(str_id)
+            if id >= 0:
+                gpu_ids.append(id)
+        opt.world_size = len(gpu_ids)
+        # if len(opt.gpu_ids) > 0:
+        #     torch.cuda.set_device(gpu_ids[0])
+        if opt.world_size == 1:
+            opt.use_ddp = False
+
+        if opt.phase != 'test':
+            # set continue_train automatically
+            if opt.pretrained_name is None:
+                model_dir = os.path.join(opt.checkpoints_dir, opt.name)
+            else:
+                model_dir = os.path.join(opt.checkpoints_dir, opt.pretrained_name)
+            if os.path.isdir(model_dir):
+                model_pths = [i for i in os.listdir(model_dir) if i.endswith('pth')]
+                if os.path.isdir(model_dir) and len(model_pths) != 0:
+                    opt.continue_train= True
+        
+            # update the latest epoch count
+            if opt.continue_train:
+                if opt.epoch == 'latest':
+                    epoch_counts = [int(i.split('.')[0].split('_')[-1]) for i in model_pths if 'latest' not in i]
+                    if len(epoch_counts) != 0:
+                        opt.epoch_count = max(epoch_counts) + 1
+                else:
+                    opt.epoch_count = int(opt.epoch) + 1
+                    
+
+        self.print_options(opt)
+        self.opt = opt
+        return self.opt
diff --git a/src/face3d/options/inference_options.py b/src/face3d/options/inference_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b9466776e120e0fe3d164217df5071c2114cef
--- /dev/null
+++ b/src/face3d/options/inference_options.py
@@ -0,0 +1,23 @@
+from face3d.options.base_options import BaseOptions
+
+
+class InferenceOptions(BaseOptions):
+    """This class includes test options.
+
+    It also includes shared options defined in BaseOptions.
+    """
+
+    def initialize(self, parser):
+        parser = BaseOptions.initialize(self, parser)  # define shared options
+        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+        parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
+
+        parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+        parser.add_argument('--keypoint_dir', type=str, help='the folder of the keypoint files')
+        parser.add_argument('--output_dir', type=str, default='mp4', help='the output dir to save the extracted coefficients')
+        parser.add_argument('--save_split_files', action='store_true', help='save split files or not')
+        parser.add_argument('--inference_batch_size', type=int, default=8)
+        
+        # Dropout and Batchnorm has different behavior during training and test.
+        self.isTrain = False
+        return parser
diff --git a/src/face3d/options/test_options.py b/src/face3d/options/test_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..f81c0c6eee0549e6fa8762dc4fc4b8573b887fe4
--- /dev/null
+++ b/src/face3d/options/test_options.py
@@ -0,0 +1,21 @@
+"""This script contains the test options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+
+
+class TestOptions(BaseOptions):
+    """This class includes test options.
+
+    It also includes shared options defined in BaseOptions.
+    """
+
+    def initialize(self, parser):
+        parser = BaseOptions.initialize(self, parser)  # define shared options
+        parser.add_argument('--phase', type=str, default='test', help='train, val, test, etc')
+        parser.add_argument('--dataset_mode', type=str, default=None, help='chooses how datasets are loaded. [None | flist]')
+        parser.add_argument('--img_folder', type=str, default='examples', help='folder for test images.')
+
+        # Dropout and Batchnorm has different behavior during training and test.
+        self.isTrain = False
+        return parser
diff --git a/src/face3d/options/train_options.py b/src/face3d/options/train_options.py
new file mode 100644
index 0000000000000000000000000000000000000000..1100b0e35cc8ef563f41f6b8219510edbef53233
--- /dev/null
+++ b/src/face3d/options/train_options.py
@@ -0,0 +1,53 @@
+"""This script contains the training options for Deep3DFaceRecon_pytorch
+"""
+
+from .base_options import BaseOptions
+from util import util
+
+class TrainOptions(BaseOptions):
+    """This class includes training options.
+
+    It also includes shared options defined in BaseOptions.
+    """
+
+    def initialize(self, parser):
+        parser = BaseOptions.initialize(self, parser)
+        # dataset parameters
+        # for train
+        parser.add_argument('--data_root', type=str, default='./', help='dataset root')
+        parser.add_argument('--flist', type=str, default='datalist/train/masks.txt', help='list of mask names of training set')
+        parser.add_argument('--batch_size', type=int, default=32)
+        parser.add_argument('--dataset_mode', type=str, default='flist', help='chooses how datasets are loaded. [None | flist]')
+        parser.add_argument('--serial_batches', action='store_true', help='if true, takes images in order to make batches, otherwise takes them randomly')
+        parser.add_argument('--num_threads', default=4, type=int, help='# threads for loading data')
+        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='Maximum number of samples allowed per dataset. If the dataset directory contains more than max_dataset_size, only a subset is loaded.')
+        parser.add_argument('--preprocess', type=str, default='shift_scale_rot_flip', help='scaling and cropping of images at load time [shift_scale_rot_flip | shift_scale | shift | shift_rot_flip ]')
+        parser.add_argument('--use_aug', type=util.str2bool, nargs='?', const=True, default=True, help='whether use data augmentation')
+
+        # for val
+        parser.add_argument('--flist_val', type=str, default='datalist/val/masks.txt', help='list of mask names of val set')
+        parser.add_argument('--batch_size_val', type=int, default=32)
+
+
+        # visualization parameters
+        parser.add_argument('--display_freq', type=int, default=1000, help='frequency of showing training results on screen')
+        parser.add_argument('--print_freq', type=int, default=100, help='frequency of showing training results on console')
+        
+        # network saving and loading parameters
+        parser.add_argument('--save_latest_freq', type=int, default=5000, help='frequency of saving the latest results')
+        parser.add_argument('--save_epoch_freq', type=int, default=1, help='frequency of saving checkpoints at the end of epochs')
+        parser.add_argument('--evaluation_freq', type=int, default=5000, help='evaluation freq')
+        parser.add_argument('--save_by_iter', action='store_true', help='whether saves model by iteration')
+        parser.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')
+        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
+        parser.add_argument('--phase', type=str, default='train', help='train, val, test, etc')
+        parser.add_argument('--pretrained_name', type=str, default=None, help='resume training from another checkpoint')
+
+        # training parameters
+        parser.add_argument('--n_epochs', type=int, default=20, help='number of epochs with the initial learning rate')
+        parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
+        parser.add_argument('--lr_policy', type=str, default='step', help='learning rate policy. [linear | step | plateau | cosine]')
+        parser.add_argument('--lr_decay_epochs', type=int, default=10, help='multiply by a gamma every lr_decay_epochs epoches')
+
+        self.isTrain = True
+        return parser
diff --git a/src/face3d/util/BBRegressorParam_r.mat b/src/face3d/util/BBRegressorParam_r.mat
new file mode 100644
index 0000000000000000000000000000000000000000..1430a94ed2ab570a09f9d980d3585e8aaa933084
Binary files /dev/null and b/src/face3d/util/BBRegressorParam_r.mat differ
diff --git a/src/face3d/util/__init__.py b/src/face3d/util/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c67833cc634a2ca310b883ae253b08687665f40
--- /dev/null
+++ b/src/face3d/util/__init__.py
@@ -0,0 +1,3 @@
+"""This package includes a miscellaneous collection of useful helper functions."""
+from src.face3d.util import *
+
diff --git a/src/face3d/util/__pycache__/__init__.cpython-38.pyc b/src/face3d/util/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..03375d34e3f7378a72a8dc0c8a31155712918524
Binary files /dev/null and b/src/face3d/util/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/face3d/util/__pycache__/__init__.cpython-39.pyc b/src/face3d/util/__pycache__/__init__.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b22fff82e58fca86defa2984928220f9ea32685d
Binary files /dev/null and b/src/face3d/util/__pycache__/__init__.cpython-39.pyc differ
diff --git a/src/face3d/util/__pycache__/load_mats.cpython-38.pyc b/src/face3d/util/__pycache__/load_mats.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f993cef20651ddb34688859321c20c4b5fc51171
Binary files /dev/null and b/src/face3d/util/__pycache__/load_mats.cpython-38.pyc differ
diff --git a/src/face3d/util/__pycache__/load_mats.cpython-39.pyc b/src/face3d/util/__pycache__/load_mats.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8bcbfc06614b0ce889cd6b851b6c862ac261696b
Binary files /dev/null and b/src/face3d/util/__pycache__/load_mats.cpython-39.pyc differ
diff --git a/src/face3d/util/__pycache__/preprocess.cpython-38.pyc b/src/face3d/util/__pycache__/preprocess.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2a45b8960f174f3a04ccf7856b9ec64c336634d1
Binary files /dev/null and b/src/face3d/util/__pycache__/preprocess.cpython-38.pyc differ
diff --git a/src/face3d/util/__pycache__/preprocess.cpython-39.pyc b/src/face3d/util/__pycache__/preprocess.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bb1130cb1e6e9978b79b0212692fd347c542c408
Binary files /dev/null and b/src/face3d/util/__pycache__/preprocess.cpython-39.pyc differ
diff --git a/src/face3d/util/detect_lm68.py b/src/face3d/util/detect_lm68.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a2cfd22b342de5c872ff07fc1c2a9920c2985b7
--- /dev/null
+++ b/src/face3d/util/detect_lm68.py
@@ -0,0 +1,106 @@
+import os
+import cv2
+import numpy as np
+from scipy.io import loadmat
+import tensorflow as tf
+from util.preprocess import align_for_lm
+from shutil import move
+
+mean_face = np.loadtxt('util/test_mean_face.txt')
+mean_face = mean_face.reshape([68, 2])
+
+def save_label(labels, save_path):
+    np.savetxt(save_path, labels)
+
+def draw_landmarks(img, landmark, save_name):
+    landmark = landmark
+    lm_img = np.zeros([img.shape[0], img.shape[1], 3])
+    lm_img[:] = img.astype(np.float32)
+    landmark = np.round(landmark).astype(np.int32)
+
+    for i in range(len(landmark)):
+        for j in range(-1, 1):
+            for k in range(-1, 1):
+                if img.shape[0] - 1 - landmark[i, 1]+j > 0 and \
+                        img.shape[0] - 1 - landmark[i, 1]+j < img.shape[0] and \
+                        landmark[i, 0]+k > 0 and \
+                        landmark[i, 0]+k < img.shape[1]:
+                    lm_img[img.shape[0] - 1 - landmark[i, 1]+j, landmark[i, 0]+k,
+                           :] = np.array([0, 0, 255])
+    lm_img = lm_img.astype(np.uint8)
+
+    cv2.imwrite(save_name, lm_img)
+
+
+def load_data(img_name, txt_name):
+    return cv2.imread(img_name), np.loadtxt(txt_name)
+
+# create tensorflow graph for landmark detector
+def load_lm_graph(graph_filename):
+    with tf.gfile.GFile(graph_filename, 'rb') as f:
+        graph_def = tf.GraphDef()
+        graph_def.ParseFromString(f.read())
+
+    with tf.Graph().as_default() as graph:
+        tf.import_graph_def(graph_def, name='net')
+        img_224 = graph.get_tensor_by_name('net/input_imgs:0')
+        output_lm = graph.get_tensor_by_name('net/lm:0')
+        lm_sess = tf.Session(graph=graph)
+
+    return lm_sess,img_224,output_lm
+
+# landmark detection
+def detect_68p(img_path,sess,input_op,output_op):
+    print('detecting landmarks......')
+    names = [i for i in sorted(os.listdir(
+        img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+    vis_path = os.path.join(img_path, 'vis')
+    remove_path = os.path.join(img_path, 'remove')
+    save_path = os.path.join(img_path, 'landmarks')
+    if not os.path.isdir(vis_path):
+        os.makedirs(vis_path)
+    if not os.path.isdir(remove_path):
+        os.makedirs(remove_path)
+    if not os.path.isdir(save_path):
+        os.makedirs(save_path)
+
+    for i in range(0, len(names)):
+        name = names[i]
+        print('%05d' % (i), ' ', name)
+        full_image_name = os.path.join(img_path, name)
+        txt_name = '.'.join(name.split('.')[:-1]) + '.txt'
+        full_txt_name = os.path.join(img_path, 'detections', txt_name) # 5 facial landmark path for each image
+
+        # if an image does not have detected 5 facial landmarks, remove it from the training list
+        if not os.path.isfile(full_txt_name):
+            move(full_image_name, os.path.join(remove_path, name))
+            continue 
+
+        # load data
+        img, five_points = load_data(full_image_name, full_txt_name)
+        input_img, scale, bbox = align_for_lm(img, five_points) # align for 68 landmark detection 
+
+        # if the alignment fails, remove corresponding image from the training list
+        if scale == 0:
+            move(full_txt_name, os.path.join(
+                remove_path, txt_name))
+            move(full_image_name, os.path.join(remove_path, name))
+            continue
+
+        # detect landmarks
+        input_img = np.reshape(
+            input_img, [1, 224, 224, 3]).astype(np.float32)
+        landmark = sess.run(
+            output_op, feed_dict={input_op: input_img})
+
+        # transform back to original image coordinate
+        landmark = landmark.reshape([68, 2]) + mean_face
+        landmark[:, 1] = 223 - landmark[:, 1]
+        landmark = landmark / scale
+        landmark[:, 0] = landmark[:, 0] + bbox[0]
+        landmark[:, 1] = landmark[:, 1] + bbox[1]
+        landmark[:, 1] = img.shape[0] - 1 - landmark[:, 1]
+
+        if i % 100 == 0:
+            draw_landmarks(img, landmark, os.path.join(vis_path, name))
+        save_label(landmark, os.path.join(save_path, txt_name))
diff --git a/src/face3d/util/generate_list.py b/src/face3d/util/generate_list.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebe93fcc5c61fbc79f4cd004a8d1bdd10ece16eb
--- /dev/null
+++ b/src/face3d/util/generate_list.py
@@ -0,0 +1,34 @@
+"""This script is to generate training list files for Deep3DFaceRecon_pytorch
+"""
+
+import os
+
+# save path to training data
+def write_list(lms_list, imgs_list, msks_list, mode='train',save_folder='datalist', save_name=''):
+    save_path = os.path.join(save_folder, mode)
+    if not os.path.isdir(save_path):
+        os.makedirs(save_path)
+    with open(os.path.join(save_path, save_name + 'landmarks.txt'), 'w') as fd:
+        fd.writelines([i + '\n' for i in lms_list])
+
+    with open(os.path.join(save_path, save_name + 'images.txt'), 'w') as fd:
+        fd.writelines([i + '\n' for i in imgs_list])
+
+    with open(os.path.join(save_path, save_name + 'masks.txt'), 'w') as fd:
+        fd.writelines([i + '\n' for i in msks_list])   
+
+# check if the path is valid
+def check_list(rlms_list, rimgs_list, rmsks_list):
+    lms_list, imgs_list, msks_list = [], [], []
+    for i in range(len(rlms_list)):
+        flag = 'false'
+        lm_path = rlms_list[i]
+        im_path = rimgs_list[i]
+        msk_path = rmsks_list[i]
+        if os.path.isfile(lm_path) and os.path.isfile(im_path) and os.path.isfile(msk_path):
+            flag = 'true'
+            lms_list.append(rlms_list[i])
+            imgs_list.append(rimgs_list[i])
+            msks_list.append(rmsks_list[i])
+        print(i, rlms_list[i], flag)
+    return lms_list, imgs_list, msks_list
diff --git a/src/face3d/util/html.py b/src/face3d/util/html.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0c4e6a66ba5a34e30cee3beb13e21465c72ef38
--- /dev/null
+++ b/src/face3d/util/html.py
@@ -0,0 +1,86 @@
+import dominate
+from dominate.tags import meta, h3, table, tr, td, p, a, img, br
+import os
+
+
+class HTML:
+    """This HTML class allows us to save images and write texts into a single HTML file.
+
+     It consists of functions such as <add_header> (add a text header to the HTML file),
+     <add_images> (add a row of images to the HTML file), and <save> (save the HTML to the disk).
+     It is based on Python library 'dominate', a Python library for creating and manipulating HTML documents using a DOM API.
+    """
+
+    def __init__(self, web_dir, title, refresh=0):
+        """Initialize the HTML classes
+
+        Parameters:
+            web_dir (str) -- a directory that stores the webpage. HTML file will be created at <web_dir>/index.html; images will be saved at <web_dir/images/
+            title (str)   -- the webpage name
+            refresh (int) -- how often the website refresh itself; if 0; no refreshing
+        """
+        self.title = title
+        self.web_dir = web_dir
+        self.img_dir = os.path.join(self.web_dir, 'images')
+        if not os.path.exists(self.web_dir):
+            os.makedirs(self.web_dir)
+        if not os.path.exists(self.img_dir):
+            os.makedirs(self.img_dir)
+
+        self.doc = dominate.document(title=title)
+        if refresh > 0:
+            with self.doc.head:
+                meta(http_equiv="refresh", content=str(refresh))
+
+    def get_image_dir(self):
+        """Return the directory that stores images"""
+        return self.img_dir
+
+    def add_header(self, text):
+        """Insert a header to the HTML file
+
+        Parameters:
+            text (str) -- the header text
+        """
+        with self.doc:
+            h3(text)
+
+    def add_images(self, ims, txts, links, width=400):
+        """add images to the HTML file
+
+        Parameters:
+            ims (str list)   -- a list of image paths
+            txts (str list)  -- a list of image names shown on the website
+            links (str list) --  a list of hyperref links; when you click an image, it will redirect you to a new page
+        """
+        self.t = table(border=1, style="table-layout: fixed;")  # Insert a table
+        self.doc.add(self.t)
+        with self.t:
+            with tr():
+                for im, txt, link in zip(ims, txts, links):
+                    with td(style="word-wrap: break-word;", halign="center", valign="top"):
+                        with p():
+                            with a(href=os.path.join('images', link)):
+                                img(style="width:%dpx" % width, src=os.path.join('images', im))
+                            br()
+                            p(txt)
+
+    def save(self):
+        """save the current content to the HMTL file"""
+        html_file = '%s/index.html' % self.web_dir
+        f = open(html_file, 'wt')
+        f.write(self.doc.render())
+        f.close()
+
+
+if __name__ == '__main__':  # we show an example usage here.
+    html = HTML('web/', 'test_html')
+    html.add_header('hello world')
+
+    ims, txts, links = [], [], []
+    for n in range(4):
+        ims.append('image_%d.png' % n)
+        txts.append('text_%d' % n)
+        links.append('image_%d.png' % n)
+    html.add_images(ims, txts, links)
+    html.save()
diff --git a/src/face3d/util/load_mats.py b/src/face3d/util/load_mats.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7ea0a7877e80035883138415c102910d896bb61
--- /dev/null
+++ b/src/face3d/util/load_mats.py
@@ -0,0 +1,120 @@
+"""This script is to load 3D face model for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from PIL import Image
+from scipy.io import loadmat, savemat
+from array import array
+import os.path as osp
+
+# load expression basis
+def LoadExpBasis(bfm_folder='BFM'):
+    n_vertex = 53215
+    Expbin = open(osp.join(bfm_folder, 'Exp_Pca.bin'), 'rb')
+    exp_dim = array('i')
+    exp_dim.fromfile(Expbin, 1)
+    expMU = array('f')
+    expPC = array('f')
+    expMU.fromfile(Expbin, 3*n_vertex)
+    expPC.fromfile(Expbin, 3*exp_dim[0]*n_vertex)
+    Expbin.close()
+
+    expPC = np.array(expPC)
+    expPC = np.reshape(expPC, [exp_dim[0], -1])
+    expPC = np.transpose(expPC)
+
+    expEV = np.loadtxt(osp.join(bfm_folder, 'std_exp.txt'))
+
+    return expPC, expEV
+
+
+# transfer original BFM09 to our face model
+def transferBFM09(bfm_folder='BFM'):
+    print('Transfer BFM09 to BFM_model_front......')
+    original_BFM = loadmat(osp.join(bfm_folder, '01_MorphableModel.mat'))
+    shapePC = original_BFM['shapePC']  # shape basis
+    shapeEV = original_BFM['shapeEV']  # corresponding eigen value
+    shapeMU = original_BFM['shapeMU']  # mean face
+    texPC = original_BFM['texPC']  # texture basis
+    texEV = original_BFM['texEV']  # eigen value
+    texMU = original_BFM['texMU']  # mean texture
+
+    expPC, expEV = LoadExpBasis(bfm_folder)
+
+    # transfer BFM09 to our face model
+
+    idBase = shapePC*np.reshape(shapeEV, [-1, 199])
+    idBase = idBase/1e5  # unify the scale to decimeter
+    idBase = idBase[:, :80]  # use only first 80 basis
+
+    exBase = expPC*np.reshape(expEV, [-1, 79])
+    exBase = exBase/1e5  # unify the scale to decimeter
+    exBase = exBase[:, :64]  # use only first 64 basis
+
+    texBase = texPC*np.reshape(texEV, [-1, 199])
+    texBase = texBase[:, :80]  # use only first 80 basis
+
+    # our face model is cropped along face landmarks and contains only 35709 vertex.
+    # original BFM09 contains 53490 vertex, and expression basis provided by Guo et al. contains 53215 vertex.
+    # thus we select corresponding vertex to get our face model.
+
+    index_exp = loadmat(osp.join(bfm_folder, 'BFM_front_idx.mat'))
+    index_exp = index_exp['idx'].astype(np.int32) - 1  # starts from 0 (to 53215)
+
+    index_shape = loadmat(osp.join(bfm_folder, 'BFM_exp_idx.mat'))
+    index_shape = index_shape['trimIndex'].astype(
+        np.int32) - 1  # starts from 0 (to 53490)
+    index_shape = index_shape[index_exp]
+
+    idBase = np.reshape(idBase, [-1, 3, 80])
+    idBase = idBase[index_shape, :, :]
+    idBase = np.reshape(idBase, [-1, 80])
+
+    texBase = np.reshape(texBase, [-1, 3, 80])
+    texBase = texBase[index_shape, :, :]
+    texBase = np.reshape(texBase, [-1, 80])
+
+    exBase = np.reshape(exBase, [-1, 3, 64])
+    exBase = exBase[index_exp, :, :]
+    exBase = np.reshape(exBase, [-1, 64])
+
+    meanshape = np.reshape(shapeMU, [-1, 3])/1e5
+    meanshape = meanshape[index_shape, :]
+    meanshape = np.reshape(meanshape, [1, -1])
+
+    meantex = np.reshape(texMU, [-1, 3])
+    meantex = meantex[index_shape, :]
+    meantex = np.reshape(meantex, [1, -1])
+
+    # other info contains triangles, region used for computing photometric loss,
+    # region used for skin texture regularization, and 68 landmarks index etc.
+    other_info = loadmat(osp.join(bfm_folder, 'facemodel_info.mat'))
+    frontmask2_idx = other_info['frontmask2_idx']
+    skinmask = other_info['skinmask']
+    keypoints = other_info['keypoints']
+    point_buf = other_info['point_buf']
+    tri = other_info['tri']
+    tri_mask2 = other_info['tri_mask2']
+
+    # save our face model
+    savemat(osp.join(bfm_folder, 'BFM_model_front.mat'), {'meanshape': meanshape, 'meantex': meantex, 'idBase': idBase, 'exBase': exBase, 'texBase': texBase,
+            'tri': tri, 'point_buf': point_buf, 'tri_mask2': tri_mask2, 'keypoints': keypoints, 'frontmask2_idx': frontmask2_idx, 'skinmask': skinmask})
+
+
+# load landmarks for standard face, which is used for image preprocessing
+def load_lm3d(bfm_folder):
+
+    Lm3D = loadmat(osp.join(bfm_folder, 'similarity_Lm3D_all.mat'))
+    Lm3D = Lm3D['lm']
+
+    # calculate 5 facial landmarks using 68 landmarks
+    lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+    Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
+        Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
+    Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
+
+    return Lm3D
+
+
+if __name__ == '__main__':
+    transferBFM09()
\ No newline at end of file
diff --git a/src/face3d/util/nvdiffrast.py b/src/face3d/util/nvdiffrast.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b345db30085de501b6718ad5b49bb5f9144dd29
--- /dev/null
+++ b/src/face3d/util/nvdiffrast.py
@@ -0,0 +1,126 @@
+"""This script is the differentiable renderer for Deep3DFaceRecon_pytorch
+    Attention, antialiasing step is missing in current version.
+"""
+import pytorch3d.ops
+import torch
+import torch.nn.functional as F
+import kornia
+from kornia.geometry.camera import pixel2cam
+import numpy as np
+from typing import List
+from scipy.io import loadmat
+from torch import nn
+
+from pytorch3d.structures import Meshes
+from pytorch3d.renderer import (
+    look_at_view_transform,
+    FoVPerspectiveCameras,
+    DirectionalLights,
+    RasterizationSettings,
+    MeshRenderer,
+    MeshRasterizer,
+    SoftPhongShader,
+    TexturesUV,
+)
+
+# def ndc_projection(x=0.1, n=1.0, f=50.0):
+#     return np.array([[n/x,    0,            0,              0],
+#                      [  0, n/-x,            0,              0],
+#                      [  0,    0, -(f+n)/(f-n), -(2*f*n)/(f-n)],
+#                      [  0,    0,           -1,              0]]).astype(np.float32)
+
+class MeshRenderer(nn.Module):
+    def __init__(self,
+                rasterize_fov,
+                znear=0.1,
+                zfar=10, 
+                rasterize_size=224):
+        super(MeshRenderer, self).__init__()
+
+        # x = np.tan(np.deg2rad(rasterize_fov * 0.5)) * znear
+        # self.ndc_proj = torch.tensor(ndc_projection(x=x, n=znear, f=zfar)).matmul(
+        #         torch.diag(torch.tensor([1., -1, -1, 1])))
+        self.rasterize_size = rasterize_size
+        self.fov = rasterize_fov
+        self.znear = znear
+        self.zfar = zfar
+
+        self.rasterizer = None
+    
+    def forward(self, vertex, tri, feat=None):
+        """
+        Return:
+            mask               -- torch.tensor, size (B, 1, H, W)
+            depth              -- torch.tensor, size (B, 1, H, W)
+            features(optional) -- torch.tensor, size (B, C, H, W) if feat is not None
+
+        Parameters:
+            vertex          -- torch.tensor, size (B, N, 3)
+            tri             -- torch.tensor, size (B, M, 3) or (M, 3), triangles
+            feat(optional)  -- torch.tensor, size (B, N ,C), features
+        """
+        device = vertex.device
+        rsize = int(self.rasterize_size)
+        # ndc_proj = self.ndc_proj.to(device)
+        # trans to homogeneous coordinates of 3d vertices, the direction of y is the same as v
+        if vertex.shape[-1] == 3:
+            vertex = torch.cat([vertex, torch.ones([*vertex.shape[:2], 1]).to(device)], dim=-1)
+            vertex[..., 0] = -vertex[..., 0]
+
+
+        # vertex_ndc = vertex @ ndc_proj.t()
+        if self.rasterizer is None:
+            self.rasterizer = MeshRasterizer()
+            print("create rasterizer on device cuda:%d"%device.index)
+        
+        # ranges = None
+        # if isinstance(tri, List) or len(tri.shape) == 3:
+        #     vum = vertex_ndc.shape[1]
+        #     fnum = torch.tensor([f.shape[0] for f in tri]).unsqueeze(1).to(device)
+        #     fstartidx = torch.cumsum(fnum, dim=0) - fnum
+        #     ranges = torch.cat([fstartidx, fnum], axis=1).type(torch.int32).cpu()
+        #     for i in range(tri.shape[0]):
+        #         tri[i] = tri[i] + i*vum
+        #     vertex_ndc = torch.cat(vertex_ndc, dim=0)
+        #     tri = torch.cat(tri, dim=0)
+
+        # for range_mode vetex: [B*N, 4], tri: [B*M, 3], for instance_mode vetex: [B, N, 4], tri: [M, 3]
+        tri = tri.type(torch.int32).contiguous()
+
+        # rasterize
+        cameras = FoVPerspectiveCameras(
+            device=device,
+            fov=self.fov,
+            znear=self.znear,
+            zfar=self.zfar,
+        )
+
+        raster_settings = RasterizationSettings(
+            image_size=rsize
+        )
+
+        # print(vertex.shape, tri.shape)
+        mesh = Meshes(vertex.contiguous()[...,:3], tri.unsqueeze(0).repeat((vertex.shape[0],1,1)))
+
+        fragments = self.rasterizer(mesh, cameras = cameras, raster_settings = raster_settings)
+        rast_out = fragments.pix_to_face.squeeze(-1)
+        depth = fragments.zbuf
+
+        # render depth
+        depth = depth.permute(0, 3, 1, 2)
+        mask = (rast_out > 0).float().unsqueeze(1)
+        depth = mask * depth
+        
+
+        image = None
+        if feat is not None:
+            attributes = feat.reshape(-1,3)[mesh.faces_packed()]
+            image = pytorch3d.ops.interpolate_face_attributes(fragments.pix_to_face,
+                                                      fragments.bary_coords,
+                                                      attributes)
+            # print(image.shape)
+            image = image.squeeze(-2).permute(0, 3, 1, 2)
+            image = mask * image
+        
+        return mask, depth, image
+
diff --git a/src/face3d/util/preprocess.py b/src/face3d/util/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e0bb54e0c82c70197c1c93e639ac284d6204f4b
--- /dev/null
+++ b/src/face3d/util/preprocess.py
@@ -0,0 +1,103 @@
+"""This script contains the image preprocessing code for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+from scipy.io import loadmat
+from PIL import Image
+import cv2
+import os
+from skimage import transform as trans
+import torch
+import warnings
+warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning) 
+warnings.filterwarnings("ignore", category=FutureWarning) 
+
+
+# calculating least square problem for image alignment
+def POS(xp, x):
+    npts = xp.shape[1]
+
+    A = np.zeros([2*npts, 8])
+
+    A[0:2*npts-1:2, 0:3] = x.transpose()
+    A[0:2*npts-1:2, 3] = 1
+
+    A[1:2*npts:2, 4:7] = x.transpose()
+    A[1:2*npts:2, 7] = 1
+
+    b = np.reshape(xp.transpose(), [2*npts, 1])
+
+    k, _, _, _ = np.linalg.lstsq(A, b)
+
+    R1 = k[0:3]
+    R2 = k[4:7]
+    sTx = k[3]
+    sTy = k[7]
+    s = (np.linalg.norm(R1) + np.linalg.norm(R2))/2
+    t = np.stack([sTx, sTy], axis=0)
+
+    return t, s
+    
+# resize and crop images for face reconstruction
+def resize_n_crop_img(img, lm, t, s, target_size=224., mask=None):
+    w0, h0 = img.size
+    w = (w0*s).astype(np.int32)
+    h = (h0*s).astype(np.int32)
+    left = (w/2 - target_size/2 + float((t[0] - w0/2)*s)).astype(np.int32)
+    right = left + target_size
+    up = (h/2 - target_size/2 + float((h0/2 - t[1])*s)).astype(np.int32)
+    below = up + target_size
+
+    img = img.resize((w, h), resample=Image.BICUBIC)
+    img = img.crop((left, up, right, below))
+
+    if mask is not None:
+        mask = mask.resize((w, h), resample=Image.BICUBIC)
+        mask = mask.crop((left, up, right, below))
+
+    lm = np.stack([lm[:, 0] - t[0] + w0/2, lm[:, 1] -
+                  t[1] + h0/2], axis=1)*s
+    lm = lm - np.reshape(
+            np.array([(w/2 - target_size/2), (h/2-target_size/2)]), [1, 2])
+
+    return img, lm, mask
+
+# utils for face reconstruction
+def extract_5p(lm):
+    lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
+    lm5p = np.stack([lm[lm_idx[0], :], np.mean(lm[lm_idx[[1, 2]], :], 0), np.mean(
+        lm[lm_idx[[3, 4]], :], 0), lm[lm_idx[5], :], lm[lm_idx[6], :]], axis=0)
+    lm5p = lm5p[[1, 2, 0, 3, 4], :]
+    return lm5p
+
+# utils for face reconstruction
+def align_img(img, lm, lm3D, mask=None, target_size=224., rescale_factor=102.):
+    """
+    Return:
+        transparams        --numpy.array  (raw_W, raw_H, scale, tx, ty)
+        img_new            --PIL.Image  (target_size, target_size, 3)
+        lm_new             --numpy.array  (68, 2), y direction is opposite to v direction
+        mask_new           --PIL.Image  (target_size, target_size)
+    
+    Parameters:
+        img                --PIL.Image  (raw_H, raw_W, 3)
+        lm                 --numpy.array  (68, 2), y direction is opposite to v direction
+        lm3D               --numpy.array  (5, 3)
+        mask               --PIL.Image  (raw_H, raw_W, 3)
+    """
+
+    w0, h0 = img.size
+    if lm.shape[0] != 5:
+        lm5p = extract_5p(lm)
+    else:
+        lm5p = lm
+
+    # calculate translation and scale factors using 5 facial landmarks and standard landmarks of a 3D face
+    t, s = POS(lm5p.transpose(), lm3D.transpose())
+    s = rescale_factor/s
+
+    # processing the image
+    img_new, lm_new, mask_new = resize_n_crop_img(img, lm, t, s, target_size=target_size, mask=mask)
+    trans_params = np.array([w0, h0, s, t[0], t[1]])
+
+    return trans_params, img_new, lm_new, mask_new
diff --git a/src/face3d/util/skin_mask.py b/src/face3d/util/skin_mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed764759038f77b35d45448b344d4347498ca427
--- /dev/null
+++ b/src/face3d/util/skin_mask.py
@@ -0,0 +1,125 @@
+"""This script is to generate skin attention mask for Deep3DFaceRecon_pytorch
+"""
+
+import math
+import numpy as np
+import os
+import cv2
+
+class GMM:
+    def __init__(self, dim, num, w, mu, cov, cov_det, cov_inv):
+        self.dim = dim # feature dimension
+        self.num = num # number of Gaussian components
+        self.w = w # weights of Gaussian components (a list of scalars)
+        self.mu= mu # mean of Gaussian components (a list of 1xdim vectors)
+        self.cov = cov # covariance matrix of Gaussian components (a list of dimxdim matrices)
+        self.cov_det = cov_det # pre-computed determinet of covariance matrices (a list of scalars)
+        self.cov_inv = cov_inv # pre-computed inverse covariance matrices (a list of dimxdim matrices)
+
+        self.factor = [0]*num
+        for i in range(self.num):
+            self.factor[i] = (2*math.pi)**(self.dim/2) * self.cov_det[i]**0.5
+        
+    def likelihood(self, data):
+        assert(data.shape[1] == self.dim)
+        N = data.shape[0]
+        lh = np.zeros(N)
+
+        for i in range(self.num):
+            data_ = data - self.mu[i]
+
+            tmp = np.matmul(data_,self.cov_inv[i]) * data_
+            tmp = np.sum(tmp,axis=1)
+            power = -0.5 * tmp
+
+            p = np.array([math.exp(power[j]) for j in range(N)])
+            p = p/self.factor[i]
+            lh += p*self.w[i]
+        
+        return lh
+
+
+def _rgb2ycbcr(rgb):
+    m = np.array([[65.481, 128.553, 24.966],
+                  [-37.797, -74.203, 112],
+                  [112, -93.786, -18.214]])
+    shape = rgb.shape
+    rgb = rgb.reshape((shape[0] * shape[1], 3))
+    ycbcr = np.dot(rgb, m.transpose() / 255.)
+    ycbcr[:, 0] += 16.
+    ycbcr[:, 1:] += 128.
+    return ycbcr.reshape(shape)
+
+
+def _bgr2ycbcr(bgr):
+    rgb = bgr[..., ::-1]
+    return _rgb2ycbcr(rgb)
+
+
+gmm_skin_w = [0.24063933, 0.16365987, 0.26034665, 0.33535415]
+gmm_skin_mu = [np.array([113.71862, 103.39613, 164.08226]),
+                np.array([150.19858, 105.18467, 155.51428]),
+                np.array([183.92976, 107.62468, 152.71820]),
+                np.array([114.90524, 113.59782, 151.38217])]
+gmm_skin_cov_det = [5692842.5, 5851930.5, 2329131., 1585971.]
+gmm_skin_cov_inv = [np.array([[0.0019472069, 0.0020450759, -0.00060243998],[0.0020450759, 0.017700525, 0.0051420014],[-0.00060243998, 0.0051420014, 0.0081308950]]),
+                    np.array([[0.0027110141, 0.0011036990, 0.0023122299],[0.0011036990, 0.010707724, 0.010742856],[0.0023122299, 0.010742856, 0.017481629]]),
+                    np.array([[0.0048026871, 0.00022935172, 0.0077668377],[0.00022935172, 0.011729696, 0.0081661865],[0.0077668377, 0.0081661865, 0.025374353]]),
+                    np.array([[0.0011989699, 0.0022453172, -0.0010748957],[0.0022453172, 0.047758564, 0.020332102],[-0.0010748957, 0.020332102, 0.024502251]])]
+
+gmm_skin = GMM(3, 4, gmm_skin_w, gmm_skin_mu, [], gmm_skin_cov_det, gmm_skin_cov_inv)
+
+gmm_nonskin_w = [0.12791070, 0.31130761, 0.34245777, 0.21832393]
+gmm_nonskin_mu = [np.array([99.200851, 112.07533, 140.20602]),
+                    np.array([110.91392, 125.52969, 130.19237]),
+                    np.array([129.75864, 129.96107, 126.96808]),
+                    np.array([112.29587, 128.85121, 129.05431])]
+gmm_nonskin_cov_det = [458703648., 6466488., 90611376., 133097.63]
+gmm_nonskin_cov_inv = [np.array([[0.00085371657, 0.00071197288, 0.00023958916],[0.00071197288, 0.0025935620, 0.00076557708],[0.00023958916, 0.00076557708, 0.0015042332]]),
+                    np.array([[0.00024650150, 0.00045542428, 0.00015019422],[0.00045542428, 0.026412144, 0.018419769],[0.00015019422, 0.018419769, 0.037497383]]),
+                    np.array([[0.00037054974, 0.00038146760, 0.00040408765],[0.00038146760, 0.0085505722, 0.0079136286],[0.00040408765, 0.0079136286, 0.010982352]]),
+                    np.array([[0.00013709733, 0.00051228428, 0.00012777430],[0.00051228428, 0.28237113, 0.10528370],[0.00012777430, 0.10528370, 0.23468947]])]
+
+gmm_nonskin = GMM(3, 4, gmm_nonskin_w, gmm_nonskin_mu, [], gmm_nonskin_cov_det, gmm_nonskin_cov_inv)
+
+prior_skin = 0.8
+prior_nonskin = 1 - prior_skin
+
+
+# calculate skin attention mask
+def skinmask(imbgr):
+    im = _bgr2ycbcr(imbgr)
+
+    data = im.reshape((-1,3))
+
+    lh_skin = gmm_skin.likelihood(data)
+    lh_nonskin = gmm_nonskin.likelihood(data)
+
+    tmp1 = prior_skin * lh_skin
+    tmp2 = prior_nonskin * lh_nonskin
+    post_skin = tmp1 / (tmp1+tmp2) # posterior probability
+
+    post_skin = post_skin.reshape((im.shape[0],im.shape[1]))
+
+    post_skin = np.round(post_skin*255)
+    post_skin = post_skin.astype(np.uint8)
+    post_skin = np.tile(np.expand_dims(post_skin,2),[1,1,3]) # reshape to H*W*3
+
+    return post_skin
+
+
+def get_skin_mask(img_path):
+    print('generating skin masks......')
+    names = [i for i in sorted(os.listdir(
+        img_path)) if 'jpg' in i or 'png' in i or 'jpeg' in i or 'PNG' in i]
+    save_path = os.path.join(img_path, 'mask')
+    if not os.path.isdir(save_path):
+        os.makedirs(save_path)
+    
+    for i in range(0, len(names)):
+        name = names[i]
+        print('%05d' % (i), ' ', name)
+        full_image_name = os.path.join(img_path, name)
+        img = cv2.imread(full_image_name).astype(np.float32)
+        skin_img = skinmask(img)
+        cv2.imwrite(os.path.join(save_path, name), skin_img.astype(np.uint8))
diff --git a/src/face3d/util/test_mean_face.txt b/src/face3d/util/test_mean_face.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1637648acf5a61cbc71b317c845414bb16d0150c
--- /dev/null
+++ b/src/face3d/util/test_mean_face.txt
@@ -0,0 +1,136 @@
+-5.228591537475585938e+01
+2.078247070312500000e-01
+-5.064269638061523438e+01
+-1.315765380859375000e+01
+-4.952939224243164062e+01
+-2.592591094970703125e+01
+-4.793047332763671875e+01
+-3.832135772705078125e+01
+-4.512159729003906250e+01
+-5.059623336791992188e+01
+-3.917720794677734375e+01
+-6.043736648559570312e+01
+-2.929953765869140625e+01
+-6.861183166503906250e+01
+-1.719801330566406250e+01
+-7.572736358642578125e+01
+-1.961936950683593750e+00
+-7.862001037597656250e+01
+1.467941284179687500e+01
+-7.607844543457031250e+01
+2.744073486328125000e+01
+-6.915261840820312500e+01
+3.855677795410156250e+01
+-5.950350570678710938e+01
+4.478240966796875000e+01
+-4.867547225952148438e+01
+4.714337158203125000e+01
+-3.800830078125000000e+01
+4.940315246582031250e+01
+-2.496297454833984375e+01
+5.117234802246093750e+01
+-1.241538238525390625e+01
+5.190507507324218750e+01
+8.244247436523437500e-01
+-4.150688934326171875e+01
+2.386329650878906250e+01
+-3.570307159423828125e+01
+3.017010498046875000e+01
+-2.790358734130859375e+01
+3.212951660156250000e+01
+-1.941773223876953125e+01
+3.156523132324218750e+01
+-1.138106536865234375e+01
+2.841992187500000000e+01
+5.993263244628906250e+00
+2.895182800292968750e+01
+1.343590545654296875e+01
+3.189880371093750000e+01
+2.203153991699218750e+01
+3.302221679687500000e+01
+2.992478942871093750e+01
+3.099150085449218750e+01
+3.628388977050781250e+01
+2.765748596191406250e+01
+-1.933914184570312500e+00
+1.405374145507812500e+01
+-2.153038024902343750e+00
+5.772636413574218750e+00
+-2.270050048828125000e+00
+-2.121643066406250000e+00
+-2.218330383300781250e+00
+-1.068978118896484375e+01
+-1.187252044677734375e+01
+-1.997912597656250000e+01
+-6.879402160644531250e+00
+-2.143579864501953125e+01
+-1.227821350097656250e+00
+-2.193494415283203125e+01
+4.623237609863281250e+00
+-2.152721405029296875e+01
+9.721397399902343750e+00
+-1.953671264648437500e+01
+-3.648714447021484375e+01
+9.811126708984375000e+00
+-3.130242919921875000e+01
+1.422447967529296875e+01
+-2.212834930419921875e+01
+1.493019866943359375e+01
+-1.500880432128906250e+01
+1.073588562011718750e+01
+-2.095037078857421875e+01
+9.054298400878906250e+00
+-3.050099182128906250e+01
+8.704177856445312500e+00
+1.173237609863281250e+01
+1.054329681396484375e+01
+1.856353759765625000e+01
+1.535009765625000000e+01
+2.893331909179687500e+01
+1.451992797851562500e+01
+3.452944946289062500e+01
+1.065280151367187500e+01
+2.875990295410156250e+01
+8.654792785644531250e+00
+1.942100524902343750e+01
+9.422447204589843750e+00
+-2.204488372802734375e+01
+-3.983994293212890625e+01
+-1.324458312988281250e+01
+-3.467377471923828125e+01
+-6.749649047851562500e+00
+-3.092894744873046875e+01
+-9.183349609375000000e-01
+-3.196458435058593750e+01
+4.220649719238281250e+00
+-3.090406036376953125e+01
+1.089889526367187500e+01
+-3.497008514404296875e+01
+1.874589538574218750e+01
+-4.065438079833984375e+01
+1.124106597900390625e+01
+-4.438417816162109375e+01
+5.181709289550781250e+00
+-4.649170684814453125e+01
+-1.158607482910156250e+00
+-4.680406951904296875e+01
+-7.918922424316406250e+00
+-4.671575164794921875e+01
+-1.452505493164062500e+01
+-4.416526031494140625e+01
+-2.005007171630859375e+01
+-3.997841644287109375e+01
+-1.054919433593750000e+01
+-3.849683380126953125e+01
+-1.051826477050781250e+00
+-3.794863128662109375e+01
+6.412681579589843750e+00
+-3.804645538330078125e+01
+1.627674865722656250e+01
+-4.039697265625000000e+01
+6.373878479003906250e+00
+-4.087213897705078125e+01
+-8.551712036132812500e-01
+-4.157129669189453125e+01
+-1.014953613281250000e+01
+-4.128469085693359375e+01
diff --git a/src/face3d/util/util.py b/src/face3d/util/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..79c7517ee66c8830a73fa86ab5e5c3513f11d869
--- /dev/null
+++ b/src/face3d/util/util.py
@@ -0,0 +1,208 @@
+"""This script contains basic utilities for Deep3DFaceRecon_pytorch
+"""
+from __future__ import print_function
+import numpy as np
+import torch
+from PIL import Image
+import os
+import importlib
+import argparse
+from argparse import Namespace
+import torchvision
+
+
+def str2bool(v):
+    if isinstance(v, bool):
+        return v
+    if v.lower() in ('yes', 'true', 't', 'y', '1'):
+        return True
+    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
+        return False
+    else:
+        raise argparse.ArgumentTypeError('Boolean value expected.')
+
+
+def copyconf(default_opt, **kwargs):
+    conf = Namespace(**vars(default_opt))
+    for key in kwargs:
+        setattr(conf, key, kwargs[key])
+    return conf
+
+def genvalconf(train_opt, **kwargs):
+    conf = Namespace(**vars(train_opt))
+    attr_dict = train_opt.__dict__
+    for key, value in attr_dict.items():
+        if 'val' in key and key.split('_')[0] in attr_dict:
+            setattr(conf, key.split('_')[0], value)
+
+    for key in kwargs:
+        setattr(conf, key, kwargs[key])
+
+    return conf
+        
+def find_class_in_module(target_cls_name, module):
+    target_cls_name = target_cls_name.replace('_', '').lower()
+    clslib = importlib.import_module(module)
+    cls = None
+    for name, clsobj in clslib.__dict__.items():
+        if name.lower() == target_cls_name:
+            cls = clsobj
+
+    assert cls is not None, "In %s, there should be a class whose name matches %s in lowercase without underscore(_)" % (module, target_cls_name)
+
+    return cls
+
+
+def tensor2im(input_image, imtype=np.uint8):
+    """"Converts a Tensor array into a numpy image array.
+
+    Parameters:
+        input_image (tensor) --  the input image tensor array, range(0, 1)
+        imtype (type)        --  the desired type of the converted numpy array
+    """
+    if not isinstance(input_image, np.ndarray):
+        if isinstance(input_image, torch.Tensor):  # get the data from a variable
+            image_tensor = input_image.data
+        else:
+            return input_image
+        image_numpy = image_tensor.clamp(0.0, 1.0).cpu().float().numpy()  # convert it into a numpy array
+        if image_numpy.shape[0] == 1:  # grayscale to RGB
+            image_numpy = np.tile(image_numpy, (3, 1, 1))
+        image_numpy = np.transpose(image_numpy, (1, 2, 0)) * 255.0  # post-processing: tranpose and scaling
+    else:  # if it is a numpy array, do nothing
+        image_numpy = input_image
+    return image_numpy.astype(imtype)
+
+
+def diagnose_network(net, name='network'):
+    """Calculate and print the mean of average absolute(gradients)
+
+    Parameters:
+        net (torch network) -- Torch network
+        name (str) -- the name of the network
+    """
+    mean = 0.0
+    count = 0
+    for param in net.parameters():
+        if param.grad is not None:
+            mean += torch.mean(torch.abs(param.grad.data))
+            count += 1
+    if count > 0:
+        mean = mean / count
+    print(name)
+    print(mean)
+
+
+def save_image(image_numpy, image_path, aspect_ratio=1.0):
+    """Save a numpy image to the disk
+
+    Parameters:
+        image_numpy (numpy array) -- input numpy array
+        image_path (str)          -- the path of the image
+    """
+
+    image_pil = Image.fromarray(image_numpy)
+    h, w, _ = image_numpy.shape
+
+    if aspect_ratio is None:
+        pass
+    elif aspect_ratio > 1.0:
+        image_pil = image_pil.resize((h, int(w * aspect_ratio)), Image.BICUBIC)
+    elif aspect_ratio < 1.0:
+        image_pil = image_pil.resize((int(h / aspect_ratio), w), Image.BICUBIC)
+    image_pil.save(image_path)
+
+
+def print_numpy(x, val=True, shp=False):
+    """Print the mean, min, max, median, std, and size of a numpy array
+
+    Parameters:
+        val (bool) -- if print the values of the numpy array
+        shp (bool) -- if print the shape of the numpy array
+    """
+    x = x.astype(np.float64)
+    if shp:
+        print('shape,', x.shape)
+    if val:
+        x = x.flatten()
+        print('mean = %3.3f, min = %3.3f, max = %3.3f, median = %3.3f, std=%3.3f' % (
+            np.mean(x), np.min(x), np.max(x), np.median(x), np.std(x)))
+
+
+def mkdirs(paths):
+    """create empty directories if they don't exist
+
+    Parameters:
+        paths (str list) -- a list of directory paths
+    """
+    if isinstance(paths, list) and not isinstance(paths, str):
+        for path in paths:
+            mkdir(path)
+    else:
+        mkdir(paths)
+
+
+def mkdir(path):
+    """create a single empty directory if it didn't exist
+
+    Parameters:
+        path (str) -- a single directory path
+    """
+    if not os.path.exists(path):
+        os.makedirs(path)
+
+
+def correct_resize_label(t, size):
+    device = t.device
+    t = t.detach().cpu()
+    resized = []
+    for i in range(t.size(0)):
+        one_t = t[i, :1]
+        one_np = np.transpose(one_t.numpy().astype(np.uint8), (1, 2, 0))
+        one_np = one_np[:, :, 0]
+        one_image = Image.fromarray(one_np).resize(size, Image.NEAREST)
+        resized_t = torch.from_numpy(np.array(one_image)).long()
+        resized.append(resized_t)
+    return torch.stack(resized, dim=0).to(device)
+
+
+def correct_resize(t, size, mode=Image.BICUBIC):
+    device = t.device
+    t = t.detach().cpu()
+    resized = []
+    for i in range(t.size(0)):
+        one_t = t[i:i + 1]
+        one_image = Image.fromarray(tensor2im(one_t)).resize(size, Image.BICUBIC)
+        resized_t = torchvision.transforms.functional.to_tensor(one_image) * 2 - 1.0
+        resized.append(resized_t)
+    return torch.stack(resized, dim=0).to(device)
+
+def draw_landmarks(img, landmark, color='r', step=2):
+    """
+    Return:
+        img              -- numpy.array, (B, H, W, 3) img with landmark, RGB order, range (0, 255)
+        
+
+    Parameters:
+        img              -- numpy.array, (B, H, W, 3), RGB order, range (0, 255)
+        landmark         -- numpy.array, (B, 68, 2), y direction is opposite to v direction
+        color            -- str, 'r' or 'b' (red or blue)
+    """
+    if color =='r':
+        c = np.array([255., 0, 0])
+    else:
+        c = np.array([0, 0, 255.])
+
+    _, H, W, _ = img.shape
+    img, landmark = img.copy(), landmark.copy()
+    landmark[..., 1] = H - 1 - landmark[..., 1]
+    landmark = np.round(landmark).astype(np.int32)
+    for i in range(landmark.shape[1]):
+        x, y = landmark[:, i, 0], landmark[:, i, 1]
+        for j in range(-step, step):
+            for k in range(-step, step):
+                u = np.clip(x + j, 0, W - 1)
+                v = np.clip(y + k, 0, H - 1)
+                for m in range(landmark.shape[0]):
+                    img[m, v[m], u[m]] = c
+    return img
diff --git a/src/face3d/util/visualizer.py b/src/face3d/util/visualizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4a8b755e054a4a34d003962a723ef189726a7a0
--- /dev/null
+++ b/src/face3d/util/visualizer.py
@@ -0,0 +1,227 @@
+"""This script defines the visualizer for Deep3DFaceRecon_pytorch
+"""
+
+import numpy as np
+import os
+import sys
+import ntpath
+import time
+from . import util, html
+from subprocess import Popen, PIPE
+from torch.utils.tensorboard import SummaryWriter
+
+def save_images(webpage, visuals, image_path, aspect_ratio=1.0, width=256):
+    """Save images to the disk.
+
+    Parameters:
+        webpage (the HTML class) -- the HTML webpage class that stores these imaegs (see html.py for more details)
+        visuals (OrderedDict)    -- an ordered dictionary that stores (name, images (either tensor or numpy) ) pairs
+        image_path (str)         -- the string is used to create image paths
+        aspect_ratio (float)     -- the aspect ratio of saved images
+        width (int)              -- the images will be resized to width x width
+
+    This function will save images stored in 'visuals' to the HTML file specified by 'webpage'.
+    """
+    image_dir = webpage.get_image_dir()
+    short_path = ntpath.basename(image_path[0])
+    name = os.path.splitext(short_path)[0]
+
+    webpage.add_header(name)
+    ims, txts, links = [], [], []
+
+    for label, im_data in visuals.items():
+        im = util.tensor2im(im_data)
+        image_name = '%s/%s.png' % (label, name)
+        os.makedirs(os.path.join(image_dir, label), exist_ok=True)
+        save_path = os.path.join(image_dir, image_name)
+        util.save_image(im, save_path, aspect_ratio=aspect_ratio)
+        ims.append(image_name)
+        txts.append(label)
+        links.append(image_name)
+    webpage.add_images(ims, txts, links, width=width)
+
+
+class Visualizer():
+    """This class includes several functions that can display/save images and print/save logging information.
+
+    It uses a Python library tensprboardX for display, and a Python library 'dominate' (wrapped in 'HTML') for creating HTML files with images.
+    """
+
+    def __init__(self, opt):
+        """Initialize the Visualizer class
+
+        Parameters:
+            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        Step 1: Cache the training/test options
+        Step 2: create a tensorboard writer
+        Step 3: create an HTML object for saveing HTML filters
+        Step 4: create a logging file to store training losses
+        """
+        self.opt = opt  # cache the option
+        self.use_html = opt.isTrain and not opt.no_html
+        self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, 'logs', opt.name))
+        self.win_size = opt.display_winsize
+        self.name = opt.name
+        self.saved = False
+        if self.use_html:  # create an HTML object at <checkpoints_dir>/web/; images will be saved under <checkpoints_dir>/web/images/
+            self.web_dir = os.path.join(opt.checkpoints_dir, opt.name, 'web')
+            self.img_dir = os.path.join(self.web_dir, 'images')
+            print('create web directory %s...' % self.web_dir)
+            util.mkdirs([self.web_dir, self.img_dir])
+        # create a logging file to store training losses
+        self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+        with open(self.log_name, "a") as log_file:
+            now = time.strftime("%c")
+            log_file.write('================ Training Loss (%s) ================\n' % now)
+
+    def reset(self):
+        """Reset the self.saved status"""
+        self.saved = False
+
+
+    def display_current_results(self, visuals, total_iters, epoch, save_result):
+        """Display current results on tensorboad; save current results to an HTML file.
+
+        Parameters:
+            visuals (OrderedDict) - - dictionary of images to display or save
+            total_iters (int) -- total iterations
+            epoch (int) - - the current epoch
+            save_result (bool) - - if save the current results to an HTML file
+        """
+        for label, image in visuals.items():
+            self.writer.add_image(label, util.tensor2im(image), total_iters, dataformats='HWC')
+
+        if self.use_html and (save_result or not self.saved):  # save images to an HTML file if they haven't been saved.
+            self.saved = True
+            # save images to the disk
+            for label, image in visuals.items():
+                image_numpy = util.tensor2im(image)
+                img_path = os.path.join(self.img_dir, 'epoch%.3d_%s.png' % (epoch, label))
+                util.save_image(image_numpy, img_path)
+
+            # update website
+            webpage = html.HTML(self.web_dir, 'Experiment name = %s' % self.name, refresh=0)
+            for n in range(epoch, 0, -1):
+                webpage.add_header('epoch [%d]' % n)
+                ims, txts, links = [], [], []
+
+                for label, image_numpy in visuals.items():
+                    image_numpy = util.tensor2im(image)
+                    img_path = 'epoch%.3d_%s.png' % (n, label)
+                    ims.append(img_path)
+                    txts.append(label)
+                    links.append(img_path)
+                webpage.add_images(ims, txts, links, width=self.win_size)
+            webpage.save()
+
+    def plot_current_losses(self, total_iters, losses):
+        # G_loss_collection = {}
+        # D_loss_collection = {}
+        # for name, value in losses.items():
+        #     if 'G' in name or 'NCE' in name or 'idt' in name:
+        #         G_loss_collection[name] = value
+        #     else:
+        #         D_loss_collection[name] = value
+        # self.writer.add_scalars('G_collec', G_loss_collection, total_iters)
+        # self.writer.add_scalars('D_collec', D_loss_collection, total_iters)
+        for name, value in losses.items():
+            self.writer.add_scalar(name, value, total_iters)
+
+    # losses: same format as |losses| of plot_current_losses
+    def print_current_losses(self, epoch, iters, losses, t_comp, t_data):
+        """print current losses on console; also save the losses to the disk
+
+        Parameters:
+            epoch (int) -- current epoch
+            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+            t_comp (float) -- computational time per data point (normalized by batch_size)
+            t_data (float) -- data loading time per data point (normalized by batch_size)
+        """
+        message = '(epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (epoch, iters, t_comp, t_data)
+        for k, v in losses.items():
+            message += '%s: %.3f ' % (k, v)
+
+        print(message)  # print the message
+        with open(self.log_name, "a") as log_file:
+            log_file.write('%s\n' % message)  # save the message
+
+
+class MyVisualizer:
+    def __init__(self, opt):
+        """Initialize the Visualizer class
+
+        Parameters:
+            opt -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        Step 1: Cache the training/test options
+        Step 2: create a tensorboard writer
+        Step 3: create an HTML object for saveing HTML filters
+        Step 4: create a logging file to store training losses
+        """
+        self.opt = opt  # cache the optio
+        self.name = opt.name
+        self.img_dir = os.path.join(opt.checkpoints_dir, opt.name, 'results')
+        
+        if opt.phase != 'test':
+            self.writer = SummaryWriter(os.path.join(opt.checkpoints_dir, opt.name, 'logs'))
+            # create a logging file to store training losses
+            self.log_name = os.path.join(opt.checkpoints_dir, opt.name, 'loss_log.txt')
+            with open(self.log_name, "a") as log_file:
+                now = time.strftime("%c")
+                log_file.write('================ Training Loss (%s) ================\n' % now)
+
+
+    def display_current_results(self, visuals, total_iters, epoch, dataset='train', save_results=False, count=0, name=None,
+            add_image=True):
+        """Display current results on tensorboad; save current results to an HTML file.
+
+        Parameters:
+            visuals (OrderedDict) - - dictionary of images to display or save
+            total_iters (int) -- total iterations
+            epoch (int) - - the current epoch
+            dataset (str) - - 'train' or 'val' or 'test'
+        """
+        # if (not add_image) and (not save_results): return
+        
+        for label, image in visuals.items():
+            for i in range(image.shape[0]):
+                image_numpy = util.tensor2im(image[i])
+                if add_image:
+                    self.writer.add_image(label + '%s_%02d'%(dataset, i + count),
+                            image_numpy, total_iters, dataformats='HWC')
+
+                if save_results:
+                    save_path = os.path.join(self.img_dir, dataset, 'epoch_%s_%06d'%(epoch, total_iters))
+                    if not os.path.isdir(save_path):
+                        os.makedirs(save_path)
+
+                    if name is not None:
+                        img_path = os.path.join(save_path, '%s.png' % name)
+                    else:
+                        img_path = os.path.join(save_path, '%s_%03d.png' % (label, i + count))
+                    util.save_image(image_numpy, img_path)
+
+
+    def plot_current_losses(self, total_iters, losses, dataset='train'):
+        for name, value in losses.items():
+            self.writer.add_scalar(name + '/%s'%dataset, value, total_iters)
+
+    # losses: same format as |losses| of plot_current_losses
+    def print_current_losses(self, epoch, iters, losses, t_comp, t_data, dataset='train'):
+        """print current losses on console; also save the losses to the disk
+
+        Parameters:
+            epoch (int) -- current epoch
+            iters (int) -- current training iteration during this epoch (reset to 0 at the end of every epoch)
+            losses (OrderedDict) -- training losses stored in the format of (name, float) pairs
+            t_comp (float) -- computational time per data point (normalized by batch_size)
+            t_data (float) -- data loading time per data point (normalized by batch_size)
+        """
+        message = '(dataset: %s, epoch: %d, iters: %d, time: %.3f, data: %.3f) ' % (
+            dataset, epoch, iters, t_comp, t_data)
+        for k, v in losses.items():
+            message += '%s: %.3f ' % (k, v)
+
+        print(message)  # print the message
+        with open(self.log_name, "a") as log_file:
+            log_file.write('%s\n' % message)  # save the message
diff --git a/src/face3d/visualize.py b/src/face3d/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..b92f935840bd1647b7b88eeccf4e0e51acebf216
--- /dev/null
+++ b/src/face3d/visualize.py
@@ -0,0 +1,48 @@
+# check the sync of 3dmm feature and the audio
+import cv2
+import numpy as np
+from src.face3d.models.bfm import ParametricFaceModel
+from src.face3d.models.facerecon_model import FaceReconModel
+import torch
+import subprocess, platform
+import scipy.io as scio
+from tqdm import tqdm 
+
+# draft
+def gen_composed_video(args, device, first_frame_coeff, coeff_path, audio_path, save_path, exp_dim=64):
+    
+    coeff_first = scio.loadmat(first_frame_coeff)['full_3dmm']
+
+    coeff_pred = scio.loadmat(coeff_path)['coeff_3dmm']
+
+    coeff_full = np.repeat(coeff_first, coeff_pred.shape[0], axis=0) # 257
+
+    coeff_full[:, 80:144] = coeff_pred[:, 0:64]
+    coeff_full[:, 224:227]  = coeff_pred[:, 64:67] # 3 dim translation
+    coeff_full[:, 254:]  = coeff_pred[:, 67:] # 3 dim translation
+
+    tmp_video_path = '/tmp/face3dtmp.mp4'
+
+    facemodel = FaceReconModel(args)
+    
+    video = cv2.VideoWriter(tmp_video_path, cv2.VideoWriter_fourcc(*'mp4v'), 25, (224, 224))
+
+    for k in tqdm(range(coeff_pred.shape[0]), 'face3d rendering:'):
+        cur_coeff_full = torch.tensor(coeff_full[k:k+1], device=device)
+
+        facemodel.forward(cur_coeff_full, device)
+
+        predicted_landmark = facemodel.pred_lm # TODO.
+        predicted_landmark = predicted_landmark.cpu().numpy().squeeze()
+
+        rendered_img = facemodel.pred_face
+        rendered_img = 255. * rendered_img.cpu().numpy().squeeze().transpose(1,2,0)
+        out_img = rendered_img[:, :, :3].astype(np.uint8)
+
+        video.write(np.uint8(out_img[:,:,::-1]))
+
+    video.release()
+
+    command = 'ffmpeg -v quiet -y -i {} -i {} -strict -2 -q:v 1 {}'.format(audio_path, tmp_video_path, save_path)
+    subprocess.call(command, shell=platform.system() != 'Windows')
+
diff --git a/src/facerender/__pycache__/animate.cpython-38.pyc b/src/facerender/__pycache__/animate.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b68daeb902b6fdc904d1ed39f5c19f4704d6c4b9
Binary files /dev/null and b/src/facerender/__pycache__/animate.cpython-38.pyc differ
diff --git a/src/facerender/animate.py b/src/facerender/animate.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e76f85bfa48606e5269e1c66d8ca66b3fe879e7
--- /dev/null
+++ b/src/facerender/animate.py
@@ -0,0 +1,220 @@
+import os
+import cv2
+import yaml
+import numpy as np
+import warnings
+from skimage import img_as_ubyte
+
+warnings.filterwarnings('ignore')
+
+
+import imageio
+import torch
+import torchvision
+
+
+from src.facerender.modules.keypoint_detector import HEEstimator, KPDetector
+from src.facerender.modules.mapping import MappingNet
+from src.facerender.modules.generator import OcclusionAwareGenerator, OcclusionAwareSPADEGenerator
+from src.facerender.modules.make_animation import make_animation 
+
+from pydub import AudioSegment 
+from src.utils.face_enhancer import enhancer as face_enhancer
+from src.utils.paste_pic import paste_pic
+from src.utils.videoio import save_video_with_watermark
+
+
+class AnimateFromCoeff():
+
+    def __init__(self, free_view_checkpoint, mapping_checkpoint,
+                   config_path, device):
+
+        with open(config_path) as f:
+            config = yaml.safe_load(f)
+
+        generator = OcclusionAwareSPADEGenerator(**config['model_params']['generator_params'],
+                                                    **config['model_params']['common_params'])
+        kp_extractor = KPDetector(**config['model_params']['kp_detector_params'],
+                                    **config['model_params']['common_params'])
+        he_estimator = HEEstimator(**config['model_params']['he_estimator_params'],
+                               **config['model_params']['common_params'])
+        mapping = MappingNet(**config['model_params']['mapping_params'])
+
+
+        generator.to(device)
+        kp_extractor.to(device)
+        he_estimator.to(device)
+        mapping.to(device)
+        for param in generator.parameters():
+            param.requires_grad = False
+        for param in kp_extractor.parameters():
+            param.requires_grad = False 
+        for param in he_estimator.parameters():
+            param.requires_grad = False
+        for param in mapping.parameters():
+            param.requires_grad = False
+
+        if free_view_checkpoint is not None:
+            self.load_cpk_facevid2vid(free_view_checkpoint, kp_detector=kp_extractor, generator=generator, he_estimator=he_estimator)
+        else:
+            raise AttributeError("Checkpoint should be specified for video head pose estimator.")
+
+        if  mapping_checkpoint is not None:
+            self.load_cpk_mapping(mapping_checkpoint, mapping=mapping)
+        else:
+            raise AttributeError("Checkpoint should be specified for video head pose estimator.") 
+
+        self.kp_extractor = kp_extractor
+        self.generator = generator
+        self.he_estimator = he_estimator
+        self.mapping = mapping
+
+        self.kp_extractor.eval()
+        self.generator.eval()
+        self.he_estimator.eval()
+        self.mapping.eval()
+         
+        self.device = device
+    
+    def load_cpk_facevid2vid(self, checkpoint_path, generator=None, discriminator=None, 
+                        kp_detector=None, he_estimator=None, optimizer_generator=None, 
+                        optimizer_discriminator=None, optimizer_kp_detector=None, 
+                        optimizer_he_estimator=None, device="cpu"):
+        checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+        if generator is not None:
+            generator.load_state_dict(checkpoint['generator'])
+        if kp_detector is not None:
+            kp_detector.load_state_dict(checkpoint['kp_detector'])
+        if he_estimator is not None:
+            he_estimator.load_state_dict(checkpoint['he_estimator'])
+        if discriminator is not None:
+            try:
+               discriminator.load_state_dict(checkpoint['discriminator'])
+            except:
+               print ('No discriminator in the state-dict. Dicriminator will be randomly initialized')
+        if optimizer_generator is not None:
+            optimizer_generator.load_state_dict(checkpoint['optimizer_generator'])
+        if optimizer_discriminator is not None:
+            try:
+                optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+            except RuntimeError as e:
+                print ('No discriminator optimizer in the state-dict. Optimizer will be not initialized')
+        if optimizer_kp_detector is not None:
+            optimizer_kp_detector.load_state_dict(checkpoint['optimizer_kp_detector'])
+        if optimizer_he_estimator is not None:
+            optimizer_he_estimator.load_state_dict(checkpoint['optimizer_he_estimator'])
+
+        return checkpoint['epoch']
+    
+    def load_cpk_mapping(self, checkpoint_path, mapping=None, discriminator=None,
+                 optimizer_mapping=None, optimizer_discriminator=None, device='cpu'):
+        checkpoint = torch.load(checkpoint_path,  map_location=torch.device(device))
+        if mapping is not None:
+            mapping.load_state_dict(checkpoint['mapping'])
+        if discriminator is not None:
+            discriminator.load_state_dict(checkpoint['discriminator'])
+        if optimizer_mapping is not None:
+            optimizer_mapping.load_state_dict(checkpoint['optimizer_mapping'])
+        if optimizer_discriminator is not None:
+            optimizer_discriminator.load_state_dict(checkpoint['optimizer_discriminator'])
+
+        return checkpoint['epoch']
+
+    def generate(self, x, video_save_dir, pic_path, crop_info, enhancer=None, background_enhancer=None, preprocess='crop'):
+
+        source_image=x['source_image'].type(torch.FloatTensor)
+        source_semantics=x['source_semantics'].type(torch.FloatTensor)
+        target_semantics=x['target_semantics_list'].type(torch.FloatTensor) 
+        source_image=source_image.to(self.device)
+        source_semantics=source_semantics.to(self.device)
+        target_semantics=target_semantics.to(self.device)
+        if 'yaw_c_seq' in x:
+            yaw_c_seq = x['yaw_c_seq'].type(torch.FloatTensor)
+            yaw_c_seq = x['yaw_c_seq'].to(self.device)
+        else:
+            yaw_c_seq = None
+        if 'pitch_c_seq' in x:
+            pitch_c_seq = x['pitch_c_seq'].type(torch.FloatTensor)
+            pitch_c_seq = x['pitch_c_seq'].to(self.device)
+        else:
+            pitch_c_seq = None
+        if 'roll_c_seq' in x:
+            roll_c_seq = x['roll_c_seq'].type(torch.FloatTensor) 
+            roll_c_seq = x['roll_c_seq'].to(self.device)
+        else:
+            roll_c_seq = None
+
+        frame_num = x['frame_num']
+
+        predictions_video = make_animation(source_image, source_semantics, target_semantics,
+                                        self.generator, self.kp_extractor, self.he_estimator, self.mapping, 
+                                        yaw_c_seq, pitch_c_seq, roll_c_seq, use_exp = True)
+
+        predictions_video = predictions_video.reshape((-1,)+predictions_video.shape[2:])
+        predictions_video = predictions_video[:frame_num]
+
+        video = []
+        for idx in range(predictions_video.shape[0]):
+            image = predictions_video[idx]
+            image = np.transpose(image.data.cpu().numpy(), [1, 2, 0]).astype(np.float32)
+            video.append(image)
+        result = img_as_ubyte(video)
+
+        ### the generated video is 256x256, so we  keep the aspect ratio, 
+        original_size = crop_info[0]
+        if original_size:
+            result = [ cv2.resize(result_i,(256, int(256.0 * original_size[1]/original_size[0]) )) for result_i in result ]
+        
+        video_name = x['video_name']  + '.mp4'
+        path = os.path.join(video_save_dir, 'temp_'+video_name)
+        
+        imageio.mimsave(path, result,  fps=float(25))
+
+        av_path = os.path.join(video_save_dir, video_name)
+        return_path = av_path 
+        
+        audio_path =  x['audio_path'] 
+        audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+        new_audio_path = os.path.join(video_save_dir, audio_name+'.wav')
+        print('new_audio_path',new_audio_path)
+        start_time = 0
+        # cog will not keep the .mp3 filename
+        sound = AudioSegment.from_file(audio_path)
+        frames = frame_num 
+        end_time = start_time + frames*1/25*1000
+        word1=sound.set_frame_rate(16000)
+        word = word1[start_time:end_time]
+        word.export(new_audio_path, format="wav")
+
+        base64_video,temp_file_path = save_video_with_watermark(path, new_audio_path, av_path, watermark= False)
+        print(f'The generated video is named {video_name} in {video_save_dir}')
+
+        if preprocess.lower() == 'full':
+            # only add watermark to the full image.
+            video_name_full = x['video_name']  + '_full.mp4'
+            full_video_path = os.path.join(video_save_dir, video_name_full)
+            return_path = full_video_path
+            paste_pic(path, pic_path, crop_info, new_audio_path, full_video_path)
+            print(f'The generated video is named {video_save_dir}/{video_name_full}') 
+        else:
+            full_video_path = av_path 
+
+        #### paste back then enhancers
+        if enhancer:
+            video_name_enhancer = x['video_name']  + '_enhanced.mp4'
+            enhanced_path = os.path.join(video_save_dir, 'temp_'+video_name_enhancer)
+            av_path_enhancer = os.path.join(video_save_dir, video_name_enhancer) 
+            return_path = av_path_enhancer
+            enhanced_images = face_enhancer(path, method=enhancer, bg_upsampler=background_enhancer)
+
+            imageio.mimsave(enhanced_path, enhanced_images, fps=float(25))
+            
+            base64_video,temp_file_path = save_video_with_watermark(enhanced_path, new_audio_path, av_path_enhancer, watermark= False)
+            print(f'The generated video is named {video_save_dir}/{video_name_enhancer}')
+            os.remove(enhanced_path)
+
+        os.remove(path)
+        os.remove(new_audio_path)
+
+        return return_path,base64_video,temp_file_path
+
diff --git a/src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc b/src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..972ccceb8d9fdb35164211898b8a24a2aec50e4c
Binary files /dev/null and b/src/facerender/modules/__pycache__/dense_motion.cpython-38.pyc differ
diff --git a/src/facerender/modules/__pycache__/generator.cpython-38.pyc b/src/facerender/modules/__pycache__/generator.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..47feef75a504ae1a284910ac93060d12bf86354e
Binary files /dev/null and b/src/facerender/modules/__pycache__/generator.cpython-38.pyc differ
diff --git a/src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc b/src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d59ecee29cd3b5eec0b6da2ea0134fb2f05a05cd
Binary files /dev/null and b/src/facerender/modules/__pycache__/keypoint_detector.cpython-38.pyc differ
diff --git a/src/facerender/modules/__pycache__/make_animation.cpython-38.pyc b/src/facerender/modules/__pycache__/make_animation.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..f70b28334cde05e60ccc187f97772bd2b0b325e6
Binary files /dev/null and b/src/facerender/modules/__pycache__/make_animation.cpython-38.pyc differ
diff --git a/src/facerender/modules/__pycache__/mapping.cpython-38.pyc b/src/facerender/modules/__pycache__/mapping.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c119a0c869aabb47a17944faeba65a2af9a60e08
Binary files /dev/null and b/src/facerender/modules/__pycache__/mapping.cpython-38.pyc differ
diff --git a/src/facerender/modules/__pycache__/util.cpython-38.pyc b/src/facerender/modules/__pycache__/util.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..aa0377744b0e052c55b6deb2c798f0fa9e129dd4
Binary files /dev/null and b/src/facerender/modules/__pycache__/util.cpython-38.pyc differ
diff --git a/src/facerender/modules/dense_motion.py b/src/facerender/modules/dense_motion.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c30417870e79bc005ea47a8f383c3aa406df563
--- /dev/null
+++ b/src/facerender/modules/dense_motion.py
@@ -0,0 +1,121 @@
+from torch import nn
+import torch.nn.functional as F
+import torch
+from src.facerender.modules.util import Hourglass, make_coordinate_grid, kp2gaussian
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
+
+
+class DenseMotionNetwork(nn.Module):
+    """
+    Module that predicting a dense motion from sparse motion representation given by kp_source and kp_driving
+    """
+
+    def __init__(self, block_expansion, num_blocks, max_features, num_kp, feature_channel, reshape_depth, compress,
+                 estimate_occlusion_map=False):
+        super(DenseMotionNetwork, self).__init__()
+        # self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(feature_channel+1), max_features=max_features, num_blocks=num_blocks)
+        self.hourglass = Hourglass(block_expansion=block_expansion, in_features=(num_kp+1)*(compress+1), max_features=max_features, num_blocks=num_blocks)
+
+        self.mask = nn.Conv3d(self.hourglass.out_filters, num_kp + 1, kernel_size=7, padding=3)
+
+        self.compress = nn.Conv3d(feature_channel, compress, kernel_size=1)
+        self.norm = BatchNorm3d(compress, affine=True)
+
+        if estimate_occlusion_map:
+            # self.occlusion = nn.Conv2d(reshape_channel*reshape_depth, 1, kernel_size=7, padding=3)
+            self.occlusion = nn.Conv2d(self.hourglass.out_filters*reshape_depth, 1, kernel_size=7, padding=3)
+        else:
+            self.occlusion = None
+
+        self.num_kp = num_kp
+
+
+    def create_sparse_motions(self, feature, kp_driving, kp_source):
+        bs, _, d, h, w = feature.shape
+        identity_grid = make_coordinate_grid((d, h, w), type=kp_source['value'].type())
+        identity_grid = identity_grid.view(1, 1, d, h, w, 3)
+        coordinate_grid = identity_grid - kp_driving['value'].view(bs, self.num_kp, 1, 1, 1, 3)
+        
+        # if 'jacobian' in kp_driving:
+        if 'jacobian' in kp_driving and kp_driving['jacobian'] is not None:
+            jacobian = torch.matmul(kp_source['jacobian'], torch.inverse(kp_driving['jacobian']))
+            jacobian = jacobian.unsqueeze(-3).unsqueeze(-3).unsqueeze(-3)
+            jacobian = jacobian.repeat(1, 1, d, h, w, 1, 1)
+            coordinate_grid = torch.matmul(jacobian, coordinate_grid.unsqueeze(-1))
+            coordinate_grid = coordinate_grid.squeeze(-1)                  
+
+
+        driving_to_source = coordinate_grid + kp_source['value'].view(bs, self.num_kp, 1, 1, 1, 3)    # (bs, num_kp, d, h, w, 3)
+
+        #adding background feature
+        identity_grid = identity_grid.repeat(bs, 1, 1, 1, 1, 1)
+        sparse_motions = torch.cat([identity_grid, driving_to_source], dim=1)                #bs num_kp+1 d h w 3
+        
+        # sparse_motions = driving_to_source
+
+        return sparse_motions
+
+    def create_deformed_feature(self, feature, sparse_motions):
+        bs, _, d, h, w = feature.shape
+        feature_repeat = feature.unsqueeze(1).unsqueeze(1).repeat(1, self.num_kp+1, 1, 1, 1, 1, 1)      # (bs, num_kp+1, 1, c, d, h, w)
+        feature_repeat = feature_repeat.view(bs * (self.num_kp+1), -1, d, h, w)                         # (bs*(num_kp+1), c, d, h, w)
+        sparse_motions = sparse_motions.view((bs * (self.num_kp+1), d, h, w, -1))                       # (bs*(num_kp+1), d, h, w, 3) !!!!
+        sparse_deformed = F.grid_sample(feature_repeat, sparse_motions)
+        sparse_deformed = sparse_deformed.view((bs, self.num_kp+1, -1, d, h, w))                        # (bs, num_kp+1, c, d, h, w)
+        return sparse_deformed
+
+    def create_heatmap_representations(self, feature, kp_driving, kp_source):
+        spatial_size = feature.shape[3:]
+        gaussian_driving = kp2gaussian(kp_driving, spatial_size=spatial_size, kp_variance=0.01)
+        gaussian_source = kp2gaussian(kp_source, spatial_size=spatial_size, kp_variance=0.01)
+        heatmap = gaussian_driving - gaussian_source
+
+        # adding background feature
+        zeros = torch.zeros(heatmap.shape[0], 1, spatial_size[0], spatial_size[1], spatial_size[2]).type(heatmap.type())
+        heatmap = torch.cat([zeros, heatmap], dim=1)
+        heatmap = heatmap.unsqueeze(2)         # (bs, num_kp+1, 1, d, h, w)
+        return heatmap
+
+    def forward(self, feature, kp_driving, kp_source):
+        bs, _, d, h, w = feature.shape
+
+        feature = self.compress(feature)
+        feature = self.norm(feature)
+        feature = F.relu(feature)
+
+        out_dict = dict()
+        sparse_motion = self.create_sparse_motions(feature, kp_driving, kp_source)
+        deformed_feature = self.create_deformed_feature(feature, sparse_motion)
+
+        heatmap = self.create_heatmap_representations(deformed_feature, kp_driving, kp_source)
+
+        input_ = torch.cat([heatmap, deformed_feature], dim=2)
+        input_ = input_.view(bs, -1, d, h, w)
+
+        # input = deformed_feature.view(bs, -1, d, h, w)      # (bs, num_kp+1 * c, d, h, w)
+
+        prediction = self.hourglass(input_)
+
+
+        mask = self.mask(prediction)
+        mask = F.softmax(mask, dim=1)
+        out_dict['mask'] = mask
+        mask = mask.unsqueeze(2)                                   # (bs, num_kp+1, 1, d, h, w)
+        
+        zeros_mask = torch.zeros_like(mask)   
+        mask = torch.where(mask < 1e-3, zeros_mask, mask) 
+
+        sparse_motion = sparse_motion.permute(0, 1, 5, 2, 3, 4)    # (bs, num_kp+1, 3, d, h, w)
+        deformation = (sparse_motion * mask).sum(dim=1)            # (bs, 3, d, h, w)
+        deformation = deformation.permute(0, 2, 3, 4, 1)           # (bs, d, h, w, 3)
+
+        out_dict['deformation'] = deformation
+
+        if self.occlusion:
+            bs, c, d, h, w = prediction.shape
+            prediction = prediction.view(bs, -1, h, w)
+            occlusion_map = torch.sigmoid(self.occlusion(prediction))
+            out_dict['occlusion_map'] = occlusion_map
+
+        return out_dict
diff --git a/src/facerender/modules/discriminator.py b/src/facerender/modules/discriminator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc0a2b460d2175a958d7b230b7e5233d7d7c7f92
--- /dev/null
+++ b/src/facerender/modules/discriminator.py
@@ -0,0 +1,90 @@
+from torch import nn
+import torch.nn.functional as F
+from facerender.modules.util import kp2gaussian
+import torch
+
+
+class DownBlock2d(nn.Module):
+    """
+    Simple block for processing video (encoder).
+    """
+
+    def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
+        super(DownBlock2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
+
+        if sn:
+            self.conv = nn.utils.spectral_norm(self.conv)
+
+        if norm:
+            self.norm = nn.InstanceNorm2d(out_features, affine=True)
+        else:
+            self.norm = None
+        self.pool = pool
+
+    def forward(self, x):
+        out = x
+        out = self.conv(out)
+        if self.norm:
+            out = self.norm(out)
+        out = F.leaky_relu(out, 0.2)
+        if self.pool:
+            out = F.avg_pool2d(out, (2, 2))
+        return out
+
+
+class Discriminator(nn.Module):
+    """
+    Discriminator similar to Pix2Pix
+    """
+
+    def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
+                 sn=False, **kwargs):
+        super(Discriminator, self).__init__()
+
+        down_blocks = []
+        for i in range(num_blocks):
+            down_blocks.append(
+                DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
+                            min(max_features, block_expansion * (2 ** (i + 1))),
+                            norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
+
+        self.down_blocks = nn.ModuleList(down_blocks)
+        self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
+        if sn:
+            self.conv = nn.utils.spectral_norm(self.conv)
+
+    def forward(self, x):
+        feature_maps = []
+        out = x
+
+        for down_block in self.down_blocks:
+            feature_maps.append(down_block(out))
+            out = feature_maps[-1]
+        prediction_map = self.conv(out)
+
+        return feature_maps, prediction_map
+
+
+class MultiScaleDiscriminator(nn.Module):
+    """
+    Multi-scale (scale) discriminator
+    """
+
+    def __init__(self, scales=(), **kwargs):
+        super(MultiScaleDiscriminator, self).__init__()
+        self.scales = scales
+        discs = {}
+        for scale in scales:
+            discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
+        self.discs = nn.ModuleDict(discs)
+
+    def forward(self, x):
+        out_dict = {}
+        for scale, disc in self.discs.items():
+            scale = str(scale).replace('-', '.')
+            key = 'prediction_' + scale
+            feature_maps, prediction_map = disc(x[key])
+            out_dict['feature_maps_' + scale] = feature_maps
+            out_dict['prediction_map_' + scale] = prediction_map
+        return out_dict
diff --git a/src/facerender/modules/generator.py b/src/facerender/modules/generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b94dde7a37c5ddf0f74dd0317a5db3507ab0729
--- /dev/null
+++ b/src/facerender/modules/generator.py
@@ -0,0 +1,255 @@
+import torch
+from torch import nn
+import torch.nn.functional as F
+from src.facerender.modules.util import ResBlock2d, SameBlock2d, UpBlock2d, DownBlock2d, ResBlock3d, SPADEResnetBlock
+from src.facerender.modules.dense_motion import DenseMotionNetwork
+
+
+class OcclusionAwareGenerator(nn.Module):
+    """
+    Generator follows NVIDIA architecture.
+    """
+
+    def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
+                 num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
+        super(OcclusionAwareGenerator, self).__init__()
+
+        if dense_motion_params is not None:
+            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
+                                                           estimate_occlusion_map=estimate_occlusion_map,
+                                                           **dense_motion_params)
+        else:
+            self.dense_motion_network = None
+
+        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(7, 7), padding=(3, 3))
+
+        down_blocks = []
+        for i in range(num_down_blocks):
+            in_features = min(max_features, block_expansion * (2 ** i))
+            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+        self.down_blocks = nn.ModuleList(down_blocks)
+
+        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
+
+        self.reshape_channel = reshape_channel
+        self.reshape_depth = reshape_depth
+
+        self.resblocks_3d = torch.nn.Sequential()
+        for i in range(num_resblocks):
+            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
+
+        out_features = block_expansion * (2 ** (num_down_blocks))
+        self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
+        self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
+
+        self.resblocks_2d = torch.nn.Sequential()
+        for i in range(num_resblocks):
+            self.resblocks_2d.add_module('2dr' + str(i), ResBlock2d(out_features, kernel_size=3, padding=1))
+
+        up_blocks = []
+        for i in range(num_down_blocks):
+            in_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i)))
+            out_features = max(block_expansion, block_expansion * (2 ** (num_down_blocks - i - 1)))
+            up_blocks.append(UpBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+        self.up_blocks = nn.ModuleList(up_blocks)
+
+        self.final = nn.Conv2d(block_expansion, image_channel, kernel_size=(7, 7), padding=(3, 3))
+        self.estimate_occlusion_map = estimate_occlusion_map
+        self.image_channel = image_channel
+
+    def deform_input(self, inp, deformation):
+        _, d_old, h_old, w_old, _ = deformation.shape
+        _, _, d, h, w = inp.shape
+        if d_old != d or h_old != h or w_old != w:
+            deformation = deformation.permute(0, 4, 1, 2, 3)
+            deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
+            deformation = deformation.permute(0, 2, 3, 4, 1)
+        return F.grid_sample(inp, deformation)
+
+    def forward(self, source_image, kp_driving, kp_source):
+        # Encoding (downsampling) part
+        out = self.first(source_image)
+        for i in range(len(self.down_blocks)):
+            out = self.down_blocks[i](out)
+        out = self.second(out)
+        bs, c, h, w = out.shape
+        # print(out.shape)
+        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) 
+        feature_3d = self.resblocks_3d(feature_3d)
+
+        # Transforming feature representation according to deformation and occlusion
+        output_dict = {}
+        if self.dense_motion_network is not None:
+            dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
+                                                     kp_source=kp_source)
+            output_dict['mask'] = dense_motion['mask']
+
+            if 'occlusion_map' in dense_motion:
+                occlusion_map = dense_motion['occlusion_map']
+                output_dict['occlusion_map'] = occlusion_map
+            else:
+                occlusion_map = None
+            deformation = dense_motion['deformation']
+            out = self.deform_input(feature_3d, deformation)
+
+            bs, c, d, h, w = out.shape
+            out = out.view(bs, c*d, h, w)
+            out = self.third(out)
+            out = self.fourth(out)
+
+            if occlusion_map is not None:
+                if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
+                    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+                out = out * occlusion_map
+
+            # output_dict["deformed"] = self.deform_input(source_image, deformation)  # 3d deformation cannot deform 2d image
+
+        # Decoding part
+        out = self.resblocks_2d(out)
+        for i in range(len(self.up_blocks)):
+            out = self.up_blocks[i](out)
+        out = self.final(out)
+        out = F.sigmoid(out)
+
+        output_dict["prediction"] = out
+
+        return output_dict
+
+
+class SPADEDecoder(nn.Module):
+    def __init__(self):
+        super().__init__()
+        ic = 256
+        oc = 64
+        norm_G = 'spadespectralinstance'
+        label_nc = 256
+        
+        self.fc = nn.Conv2d(ic, 2 * ic, 3, padding=1)
+        self.G_middle_0 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+        self.G_middle_1 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+        self.G_middle_2 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+        self.G_middle_3 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+        self.G_middle_4 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+        self.G_middle_5 = SPADEResnetBlock(2 * ic, 2 * ic, norm_G, label_nc)
+        self.up_0 = SPADEResnetBlock(2 * ic, ic, norm_G, label_nc)
+        self.up_1 = SPADEResnetBlock(ic, oc, norm_G, label_nc)
+        self.conv_img = nn.Conv2d(oc, 3, 3, padding=1)
+        self.up = nn.Upsample(scale_factor=2)
+        
+    def forward(self, feature):
+        seg = feature
+        x = self.fc(feature)
+        x = self.G_middle_0(x, seg)
+        x = self.G_middle_1(x, seg)
+        x = self.G_middle_2(x, seg)
+        x = self.G_middle_3(x, seg)
+        x = self.G_middle_4(x, seg)
+        x = self.G_middle_5(x, seg)
+        x = self.up(x)                
+        x = self.up_0(x, seg)         # 256, 128, 128
+        x = self.up(x)                
+        x = self.up_1(x, seg)         # 64, 256, 256
+
+        x = self.conv_img(F.leaky_relu(x, 2e-1))
+        # x = torch.tanh(x)
+        x = F.sigmoid(x)
+        
+        return x
+
+
+class OcclusionAwareSPADEGenerator(nn.Module):
+
+    def __init__(self, image_channel, feature_channel, num_kp, block_expansion, max_features, num_down_blocks, reshape_channel, reshape_depth,
+                 num_resblocks, estimate_occlusion_map=False, dense_motion_params=None, estimate_jacobian=False):
+        super(OcclusionAwareSPADEGenerator, self).__init__()
+
+        if dense_motion_params is not None:
+            self.dense_motion_network = DenseMotionNetwork(num_kp=num_kp, feature_channel=feature_channel,
+                                                           estimate_occlusion_map=estimate_occlusion_map,
+                                                           **dense_motion_params)
+        else:
+            self.dense_motion_network = None
+
+        self.first = SameBlock2d(image_channel, block_expansion, kernel_size=(3, 3), padding=(1, 1))
+
+        down_blocks = []
+        for i in range(num_down_blocks):
+            in_features = min(max_features, block_expansion * (2 ** i))
+            out_features = min(max_features, block_expansion * (2 ** (i + 1)))
+            down_blocks.append(DownBlock2d(in_features, out_features, kernel_size=(3, 3), padding=(1, 1)))
+        self.down_blocks = nn.ModuleList(down_blocks)
+
+        self.second = nn.Conv2d(in_channels=out_features, out_channels=max_features, kernel_size=1, stride=1)
+
+        self.reshape_channel = reshape_channel
+        self.reshape_depth = reshape_depth
+
+        self.resblocks_3d = torch.nn.Sequential()
+        for i in range(num_resblocks):
+            self.resblocks_3d.add_module('3dr' + str(i), ResBlock3d(reshape_channel, kernel_size=3, padding=1))
+
+        out_features = block_expansion * (2 ** (num_down_blocks))
+        self.third = SameBlock2d(max_features, out_features, kernel_size=(3, 3), padding=(1, 1), lrelu=True)
+        self.fourth = nn.Conv2d(in_channels=out_features, out_channels=out_features, kernel_size=1, stride=1)
+
+        self.estimate_occlusion_map = estimate_occlusion_map
+        self.image_channel = image_channel
+
+        self.decoder = SPADEDecoder()
+
+    def deform_input(self, inp, deformation):
+        _, d_old, h_old, w_old, _ = deformation.shape
+        _, _, d, h, w = inp.shape
+        if d_old != d or h_old != h or w_old != w:
+            deformation = deformation.permute(0, 4, 1, 2, 3)
+            deformation = F.interpolate(deformation, size=(d, h, w), mode='trilinear')
+            deformation = deformation.permute(0, 2, 3, 4, 1)
+        return F.grid_sample(inp, deformation)
+
+    def forward(self, source_image, kp_driving, kp_source):
+        # Encoding (downsampling) part
+        out = self.first(source_image)
+        for i in range(len(self.down_blocks)):
+            out = self.down_blocks[i](out)
+        out = self.second(out)
+        bs, c, h, w = out.shape
+        # print(out.shape)
+        feature_3d = out.view(bs, self.reshape_channel, self.reshape_depth, h ,w) 
+        feature_3d = self.resblocks_3d(feature_3d)
+
+        # Transforming feature representation according to deformation and occlusion
+        output_dict = {}
+        if self.dense_motion_network is not None:
+            dense_motion = self.dense_motion_network(feature=feature_3d, kp_driving=kp_driving,
+                                                     kp_source=kp_source)
+            output_dict['mask'] = dense_motion['mask']
+
+            # import pdb; pdb.set_trace()
+
+            if 'occlusion_map' in dense_motion:
+                occlusion_map = dense_motion['occlusion_map']
+                output_dict['occlusion_map'] = occlusion_map
+            else:
+                occlusion_map = None
+            deformation = dense_motion['deformation']
+            out = self.deform_input(feature_3d, deformation)
+
+            bs, c, d, h, w = out.shape
+            out = out.view(bs, c*d, h, w)
+            out = self.third(out)
+            out = self.fourth(out)
+
+            # occlusion_map = torch.where(occlusion_map < 0.95, 0, occlusion_map)
+            
+            if occlusion_map is not None:
+                if out.shape[2] != occlusion_map.shape[2] or out.shape[3] != occlusion_map.shape[3]:
+                    occlusion_map = F.interpolate(occlusion_map, size=out.shape[2:], mode='bilinear')
+                out = out * occlusion_map
+
+        # Decoding part
+        out = self.decoder(out)
+
+        output_dict["prediction"] = out
+        
+        return output_dict
\ No newline at end of file
diff --git a/src/facerender/modules/keypoint_detector.py b/src/facerender/modules/keypoint_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..e56800c7b1e94bb3cbf97200cd3f059ce9d29cf3
--- /dev/null
+++ b/src/facerender/modules/keypoint_detector.py
@@ -0,0 +1,179 @@
+from torch import nn
+import torch
+import torch.nn.functional as F
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+from src.facerender.modules.util import KPHourglass, make_coordinate_grid, AntiAliasInterpolation2d, ResBottleneck
+
+
+class KPDetector(nn.Module):
+    """
+    Detecting canonical keypoints. Return keypoint position and jacobian near each keypoint.
+    """
+
+    def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, reshape_channel, reshape_depth,
+                 num_blocks, temperature, estimate_jacobian=False, scale_factor=1, single_jacobian_map=False):
+        super(KPDetector, self).__init__()
+
+        self.predictor = KPHourglass(block_expansion, in_features=image_channel,
+                                     max_features=max_features,  reshape_features=reshape_channel, reshape_depth=reshape_depth, num_blocks=num_blocks)
+
+        # self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=7, padding=3)
+        self.kp = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=num_kp, kernel_size=3, padding=1)
+
+        if estimate_jacobian:
+            self.num_jacobian_maps = 1 if single_jacobian_map else num_kp
+            # self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=7, padding=3)
+            self.jacobian = nn.Conv3d(in_channels=self.predictor.out_filters, out_channels=9 * self.num_jacobian_maps, kernel_size=3, padding=1)
+            '''
+            initial as:
+            [[1 0 0]
+             [0 1 0]
+             [0 0 1]]
+            '''
+            self.jacobian.weight.data.zero_()
+            self.jacobian.bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0, 0, 0, 1] * self.num_jacobian_maps, dtype=torch.float))
+        else:
+            self.jacobian = None
+
+        self.temperature = temperature
+        self.scale_factor = scale_factor
+        if self.scale_factor != 1:
+            self.down = AntiAliasInterpolation2d(image_channel, self.scale_factor)
+
+    def gaussian2kp(self, heatmap):
+        """
+        Extract the mean from a heatmap
+        """
+        shape = heatmap.shape
+        heatmap = heatmap.unsqueeze(-1)
+        grid = make_coordinate_grid(shape[2:], heatmap.type()).unsqueeze_(0).unsqueeze_(0)
+        value = (heatmap * grid).sum(dim=(2, 3, 4))
+        kp = {'value': value}
+
+        return kp
+
+    def forward(self, x):
+        if self.scale_factor != 1:
+            x = self.down(x)
+
+        feature_map = self.predictor(x)
+        prediction = self.kp(feature_map)
+
+        final_shape = prediction.shape
+        heatmap = prediction.view(final_shape[0], final_shape[1], -1)
+        heatmap = F.softmax(heatmap / self.temperature, dim=2)
+        heatmap = heatmap.view(*final_shape)
+
+        out = self.gaussian2kp(heatmap)
+
+        if self.jacobian is not None:
+            jacobian_map = self.jacobian(feature_map)
+            jacobian_map = jacobian_map.reshape(final_shape[0], self.num_jacobian_maps, 9, final_shape[2],
+                                                final_shape[3], final_shape[4])
+            heatmap = heatmap.unsqueeze(2)
+
+            jacobian = heatmap * jacobian_map
+            jacobian = jacobian.view(final_shape[0], final_shape[1], 9, -1)
+            jacobian = jacobian.sum(dim=-1)
+            jacobian = jacobian.view(jacobian.shape[0], jacobian.shape[1], 3, 3)
+            out['jacobian'] = jacobian
+
+        return out
+
+
+class HEEstimator(nn.Module):
+    """
+    Estimating head pose and expression.
+    """
+
+    def __init__(self, block_expansion, feature_channel, num_kp, image_channel, max_features, num_bins=66, estimate_jacobian=True):
+        super(HEEstimator, self).__init__()
+
+        self.conv1 = nn.Conv2d(in_channels=image_channel, out_channels=block_expansion, kernel_size=7, padding=3, stride=2)
+        self.norm1 = BatchNorm2d(block_expansion, affine=True)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        self.conv2 = nn.Conv2d(in_channels=block_expansion, out_channels=256, kernel_size=1)
+        self.norm2 = BatchNorm2d(256, affine=True)
+
+        self.block1 = nn.Sequential()
+        for i in range(3):
+            self.block1.add_module('b1_'+ str(i), ResBottleneck(in_features=256, stride=1))
+
+        self.conv3 = nn.Conv2d(in_channels=256, out_channels=512, kernel_size=1)
+        self.norm3 = BatchNorm2d(512, affine=True)
+        self.block2 = ResBottleneck(in_features=512, stride=2)
+
+        self.block3 = nn.Sequential()
+        for i in range(3):
+            self.block3.add_module('b3_'+ str(i), ResBottleneck(in_features=512, stride=1))
+
+        self.conv4 = nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=1)
+        self.norm4 = BatchNorm2d(1024, affine=True)
+        self.block4 = ResBottleneck(in_features=1024, stride=2)
+
+        self.block5 = nn.Sequential()
+        for i in range(5):
+            self.block5.add_module('b5_'+ str(i), ResBottleneck(in_features=1024, stride=1))
+
+        self.conv5 = nn.Conv2d(in_channels=1024, out_channels=2048, kernel_size=1)
+        self.norm5 = BatchNorm2d(2048, affine=True)
+        self.block6 = ResBottleneck(in_features=2048, stride=2)
+
+        self.block7 = nn.Sequential()
+        for i in range(2):
+            self.block7.add_module('b7_'+ str(i), ResBottleneck(in_features=2048, stride=1))
+
+        self.fc_roll = nn.Linear(2048, num_bins)
+        self.fc_pitch = nn.Linear(2048, num_bins)
+        self.fc_yaw = nn.Linear(2048, num_bins)
+
+        self.fc_t = nn.Linear(2048, 3)
+
+        self.fc_exp = nn.Linear(2048, 3*num_kp)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.norm1(out)
+        out = F.relu(out)
+        out = self.maxpool(out)
+
+        out = self.conv2(out)
+        out = self.norm2(out)
+        out = F.relu(out)
+
+        out = self.block1(out)
+
+        out = self.conv3(out)
+        out = self.norm3(out)
+        out = F.relu(out)
+        out = self.block2(out)
+
+        out = self.block3(out)
+
+        out = self.conv4(out)
+        out = self.norm4(out)
+        out = F.relu(out)
+        out = self.block4(out)
+
+        out = self.block5(out)
+
+        out = self.conv5(out)
+        out = self.norm5(out)
+        out = F.relu(out)
+        out = self.block6(out)
+
+        out = self.block7(out)
+
+        out = F.adaptive_avg_pool2d(out, 1)
+        out = out.view(out.shape[0], -1)
+
+        yaw = self.fc_roll(out)
+        pitch = self.fc_pitch(out)
+        roll = self.fc_yaw(out)
+        t = self.fc_t(out)
+        exp = self.fc_exp(out)
+
+        return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp}
+
diff --git a/src/facerender/modules/make_animation.py b/src/facerender/modules/make_animation.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9f2d65e6e62b9a15bbb289c3e24dc8c99f2ab69
--- /dev/null
+++ b/src/facerender/modules/make_animation.py
@@ -0,0 +1,170 @@
+from scipy.spatial import ConvexHull
+import torch
+import torch.nn.functional as F
+import numpy as np
+from tqdm import tqdm 
+
+def normalize_kp(kp_source, kp_driving, kp_driving_initial, adapt_movement_scale=False,
+                 use_relative_movement=False, use_relative_jacobian=False):
+    if adapt_movement_scale:
+        source_area = ConvexHull(kp_source['value'][0].data.cpu().numpy()).volume
+        driving_area = ConvexHull(kp_driving_initial['value'][0].data.cpu().numpy()).volume
+        adapt_movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
+    else:
+        adapt_movement_scale = 1
+
+    kp_new = {k: v for k, v in kp_driving.items()}
+
+    if use_relative_movement:
+        kp_value_diff = (kp_driving['value'] - kp_driving_initial['value'])
+        kp_value_diff *= adapt_movement_scale
+        kp_new['value'] = kp_value_diff + kp_source['value']
+
+        if use_relative_jacobian:
+            jacobian_diff = torch.matmul(kp_driving['jacobian'], torch.inverse(kp_driving_initial['jacobian']))
+            kp_new['jacobian'] = torch.matmul(jacobian_diff, kp_source['jacobian'])
+
+    return kp_new
+
+def headpose_pred_to_degree(pred):
+    device = pred.device
+    idx_tensor = [idx for idx in range(66)]
+    idx_tensor = torch.FloatTensor(idx_tensor).to(device)
+    pred = F.softmax(pred)
+    degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
+    return degree
+
+def get_rotation_matrix(yaw, pitch, roll):
+    yaw = yaw / 180 * 3.14
+    pitch = pitch / 180 * 3.14
+    roll = roll / 180 * 3.14
+
+    roll = roll.unsqueeze(1)
+    pitch = pitch.unsqueeze(1)
+    yaw = yaw.unsqueeze(1)
+
+    pitch_mat = torch.cat([torch.ones_like(pitch), torch.zeros_like(pitch), torch.zeros_like(pitch), 
+                          torch.zeros_like(pitch), torch.cos(pitch), -torch.sin(pitch),
+                          torch.zeros_like(pitch), torch.sin(pitch), torch.cos(pitch)], dim=1)
+    pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
+
+    yaw_mat = torch.cat([torch.cos(yaw), torch.zeros_like(yaw), torch.sin(yaw), 
+                           torch.zeros_like(yaw), torch.ones_like(yaw), torch.zeros_like(yaw),
+                           -torch.sin(yaw), torch.zeros_like(yaw), torch.cos(yaw)], dim=1)
+    yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
+
+    roll_mat = torch.cat([torch.cos(roll), -torch.sin(roll), torch.zeros_like(roll),  
+                         torch.sin(roll), torch.cos(roll), torch.zeros_like(roll),
+                         torch.zeros_like(roll), torch.zeros_like(roll), torch.ones_like(roll)], dim=1)
+    roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
+
+    rot_mat = torch.einsum('bij,bjk,bkm->bim', pitch_mat, yaw_mat, roll_mat)
+
+    return rot_mat
+
+def keypoint_transformation(kp_canonical, he, wo_exp=False):
+    kp = kp_canonical['value']    # (bs, k, 3) 
+    yaw, pitch, roll= he['yaw'], he['pitch'], he['roll']      
+    yaw = headpose_pred_to_degree(yaw) 
+    pitch = headpose_pred_to_degree(pitch)
+    roll = headpose_pred_to_degree(roll)
+
+    if 'yaw_in' in he:
+        yaw = he['yaw_in']
+    if 'pitch_in' in he:
+        pitch = he['pitch_in']
+    if 'roll_in' in he:
+        roll = he['roll_in']
+
+    rot_mat = get_rotation_matrix(yaw, pitch, roll)    # (bs, 3, 3)
+
+    t, exp = he['t'], he['exp']
+    if wo_exp:
+        exp =  exp*0  
+    
+    # keypoint rotation
+    kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
+
+    # keypoint translation
+    t[:, 0] = t[:, 0]*0
+    t[:, 2] = t[:, 2]*0
+    t = t.unsqueeze(1).repeat(1, kp.shape[1], 1)
+    kp_t = kp_rotated + t
+
+    # add expression deviation 
+    exp = exp.view(exp.shape[0], -1, 3)
+    kp_transformed = kp_t + exp
+
+    return {'value': kp_transformed}
+
+
+
+def make_animation(source_image, source_semantics, target_semantics,
+                            generator, kp_detector, he_estimator, mapping, 
+                            yaw_c_seq=None, pitch_c_seq=None, roll_c_seq=None,
+                            use_exp=True):
+    with torch.no_grad():
+        predictions = []
+
+        kp_canonical = kp_detector(source_image)
+        he_source = mapping(source_semantics)
+        kp_source = keypoint_transformation(kp_canonical, he_source)
+    
+        for frame_idx in tqdm(range(target_semantics.shape[1]), 'Face Renderer:'):
+            target_semantics_frame = target_semantics[:, frame_idx]
+            he_driving = mapping(target_semantics_frame)
+            if yaw_c_seq is not None:
+                he_driving['yaw_in'] = yaw_c_seq[:, frame_idx]
+            if pitch_c_seq is not None:
+                he_driving['pitch_in'] = pitch_c_seq[:, frame_idx] 
+            if roll_c_seq is not None:
+                he_driving['roll_in'] = roll_c_seq[:, frame_idx] 
+            
+            kp_driving = keypoint_transformation(kp_canonical, he_driving)
+                
+            #kp_norm = normalize_kp(kp_source=kp_source, kp_driving=kp_driving,
+                                   #kp_driving_initial=kp_driving_initial)
+            kp_norm = kp_driving
+            out = generator(source_image, kp_source=kp_source, kp_driving=kp_norm)
+            '''
+            source_image_new = out['prediction'].squeeze(1)
+            kp_canonical_new =  kp_detector(source_image_new)
+            he_source_new = he_estimator(source_image_new) 
+            kp_source_new = keypoint_transformation(kp_canonical_new, he_source_new, wo_exp=True)
+            kp_driving_new = keypoint_transformation(kp_canonical_new, he_driving, wo_exp=True)
+            out = generator(source_image_new, kp_source=kp_source_new, kp_driving=kp_driving_new)
+            '''
+            predictions.append(out['prediction'])
+        predictions_ts = torch.stack(predictions, dim=1)
+    return predictions_ts
+
+class AnimateModel(torch.nn.Module):
+    """
+    Merge all generator related updates into single model for better multi-gpu usage
+    """
+
+    def __init__(self, generator, kp_extractor, mapping):
+        super(AnimateModel, self).__init__()
+        self.kp_extractor = kp_extractor
+        self.generator = generator
+        self.mapping = mapping
+
+        self.kp_extractor.eval()
+        self.generator.eval()
+        self.mapping.eval()
+
+    def forward(self, x):
+        
+        source_image = x['source_image']
+        source_semantics = x['source_semantics']
+        target_semantics = x['target_semantics']
+        yaw_c_seq = x['yaw_c_seq']
+        pitch_c_seq = x['pitch_c_seq']
+        roll_c_seq = x['roll_c_seq']
+
+        predictions_video = make_animation(source_image, source_semantics, target_semantics,
+                                        self.generator, self.kp_extractor,
+                                        self.mapping, use_exp = True,
+                                        yaw_c_seq=yaw_c_seq, pitch_c_seq=pitch_c_seq, roll_c_seq=roll_c_seq)
+        
+        return predictions_video
\ No newline at end of file
diff --git a/src/facerender/modules/mapping.py b/src/facerender/modules/mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ac98dd9e177b949f71f8f47029b66d67ece05b4
--- /dev/null
+++ b/src/facerender/modules/mapping.py
@@ -0,0 +1,47 @@
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class MappingNet(nn.Module):
+    def __init__(self, coeff_nc, descriptor_nc, layer, num_kp, num_bins):
+        super( MappingNet, self).__init__()
+
+        self.layer = layer
+        nonlinearity = nn.LeakyReLU(0.1)
+
+        self.first = nn.Sequential(
+            torch.nn.Conv1d(coeff_nc, descriptor_nc, kernel_size=7, padding=0, bias=True))
+
+        for i in range(layer):
+            net = nn.Sequential(nonlinearity,
+                torch.nn.Conv1d(descriptor_nc, descriptor_nc, kernel_size=3, padding=0, dilation=3))
+            setattr(self, 'encoder' + str(i), net)   
+
+        self.pooling = nn.AdaptiveAvgPool1d(1)
+        self.output_nc = descriptor_nc
+
+        self.fc_roll = nn.Linear(descriptor_nc, num_bins)
+        self.fc_pitch = nn.Linear(descriptor_nc, num_bins)
+        self.fc_yaw = nn.Linear(descriptor_nc, num_bins)
+        self.fc_t = nn.Linear(descriptor_nc, 3)
+        self.fc_exp = nn.Linear(descriptor_nc, 3*num_kp)
+
+    def forward(self, input_3dmm):
+        out = self.first(input_3dmm)
+        for i in range(self.layer):
+            model = getattr(self, 'encoder' + str(i))
+            out = model(out) + out[:,:,3:-3]
+        out = self.pooling(out)
+        out = out.view(out.shape[0], -1)
+        #print('out:', out.shape)
+
+        yaw = self.fc_yaw(out)
+        pitch = self.fc_pitch(out)
+        roll = self.fc_roll(out)
+        t = self.fc_t(out)
+        exp = self.fc_exp(out)
+
+        return {'yaw': yaw, 'pitch': pitch, 'roll': roll, 't': t, 'exp': exp} 
\ No newline at end of file
diff --git a/src/facerender/modules/util.py b/src/facerender/modules/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3bfb1f26427b491f032ca9952db41cdeb793d70
--- /dev/null
+++ b/src/facerender/modules/util.py
@@ -0,0 +1,564 @@
+from torch import nn
+
+import torch.nn.functional as F
+import torch
+
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm2d as BatchNorm2d
+from src.facerender.sync_batchnorm import SynchronizedBatchNorm3d as BatchNorm3d
+
+import torch.nn.utils.spectral_norm as spectral_norm
+
+
+def kp2gaussian(kp, spatial_size, kp_variance):
+    """
+    Transform a keypoint into gaussian like representation
+    """
+    mean = kp['value']
+
+    coordinate_grid = make_coordinate_grid(spatial_size, mean.type())
+    number_of_leading_dimensions = len(mean.shape) - 1
+    shape = (1,) * number_of_leading_dimensions + coordinate_grid.shape
+    coordinate_grid = coordinate_grid.view(*shape)
+    repeats = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 1)
+    coordinate_grid = coordinate_grid.repeat(*repeats)
+
+    # Preprocess kp shape
+    shape = mean.shape[:number_of_leading_dimensions] + (1, 1, 1, 3)
+    mean = mean.view(*shape)
+
+    mean_sub = (coordinate_grid - mean)
+
+    out = torch.exp(-0.5 * (mean_sub ** 2).sum(-1) / kp_variance)
+
+    return out
+
+def make_coordinate_grid_2d(spatial_size, type):
+    """
+    Create a meshgrid [-1,1] x [-1,1] of given spatial_size.
+    """
+    h, w = spatial_size
+    x = torch.arange(w).type(type)
+    y = torch.arange(h).type(type)
+
+    x = (2 * (x / (w - 1)) - 1)
+    y = (2 * (y / (h - 1)) - 1)
+
+    yy = y.view(-1, 1).repeat(1, w)
+    xx = x.view(1, -1).repeat(h, 1)
+
+    meshed = torch.cat([xx.unsqueeze_(2), yy.unsqueeze_(2)], 2)
+
+    return meshed
+
+
+def make_coordinate_grid(spatial_size, type):
+    d, h, w = spatial_size
+    x = torch.arange(w).type(type)
+    y = torch.arange(h).type(type)
+    z = torch.arange(d).type(type)
+
+    x = (2 * (x / (w - 1)) - 1)
+    y = (2 * (y / (h - 1)) - 1)
+    z = (2 * (z / (d - 1)) - 1)
+   
+    yy = y.view(1, -1, 1).repeat(d, 1, w)
+    xx = x.view(1, 1, -1).repeat(d, h, 1)
+    zz = z.view(-1, 1, 1).repeat(1, h, w)
+
+    meshed = torch.cat([xx.unsqueeze_(3), yy.unsqueeze_(3), zz.unsqueeze_(3)], 3)
+
+    return meshed
+
+
+class ResBottleneck(nn.Module):
+    def __init__(self, in_features, stride):
+        super(ResBottleneck, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features//4, kernel_size=1)
+        self.conv2 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features//4, kernel_size=3, padding=1, stride=stride)
+        self.conv3 = nn.Conv2d(in_channels=in_features//4, out_channels=in_features, kernel_size=1)
+        self.norm1 = BatchNorm2d(in_features//4, affine=True)
+        self.norm2 = BatchNorm2d(in_features//4, affine=True)
+        self.norm3 = BatchNorm2d(in_features, affine=True)
+
+        self.stride = stride
+        if self.stride != 1:
+            self.skip = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=1, stride=stride)
+            self.norm4 = BatchNorm2d(in_features, affine=True)
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.norm1(out)
+        out = F.relu(out)
+        out = self.conv2(out)
+        out = self.norm2(out)
+        out = F.relu(out)
+        out = self.conv3(out)
+        out = self.norm3(out)
+        if self.stride != 1:
+            x = self.skip(x)
+            x = self.norm4(x)
+        out += x
+        out = F.relu(out)
+        return out
+
+
+class ResBlock2d(nn.Module):
+    """
+    Res block, preserve spatial resolution.
+    """
+
+    def __init__(self, in_features, kernel_size, padding):
+        super(ResBlock2d, self).__init__()
+        self.conv1 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                               padding=padding)
+        self.conv2 = nn.Conv2d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                               padding=padding)
+        self.norm1 = BatchNorm2d(in_features, affine=True)
+        self.norm2 = BatchNorm2d(in_features, affine=True)
+
+    def forward(self, x):
+        out = self.norm1(x)
+        out = F.relu(out)
+        out = self.conv1(out)
+        out = self.norm2(out)
+        out = F.relu(out)
+        out = self.conv2(out)
+        out += x
+        return out
+
+
+class ResBlock3d(nn.Module):
+    """
+    Res block, preserve spatial resolution.
+    """
+
+    def __init__(self, in_features, kernel_size, padding):
+        super(ResBlock3d, self).__init__()
+        self.conv1 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                               padding=padding)
+        self.conv2 = nn.Conv3d(in_channels=in_features, out_channels=in_features, kernel_size=kernel_size,
+                               padding=padding)
+        self.norm1 = BatchNorm3d(in_features, affine=True)
+        self.norm2 = BatchNorm3d(in_features, affine=True)
+
+    def forward(self, x):
+        out = self.norm1(x)
+        out = F.relu(out)
+        out = self.conv1(out)
+        out = self.norm2(out)
+        out = F.relu(out)
+        out = self.conv2(out)
+        out += x
+        return out
+
+
+class UpBlock2d(nn.Module):
+    """
+    Upsampling block for use in decoder.
+    """
+
+    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+        super(UpBlock2d, self).__init__()
+
+        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+                              padding=padding, groups=groups)
+        self.norm = BatchNorm2d(out_features, affine=True)
+
+    def forward(self, x):
+        out = F.interpolate(x, scale_factor=2)
+        out = self.conv(out)
+        out = self.norm(out)
+        out = F.relu(out)
+        return out
+
+class UpBlock3d(nn.Module):
+    """
+    Upsampling block for use in decoder.
+    """
+
+    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+        super(UpBlock3d, self).__init__()
+
+        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+                              padding=padding, groups=groups)
+        self.norm = BatchNorm3d(out_features, affine=True)
+
+    def forward(self, x):
+        # out = F.interpolate(x, scale_factor=(1, 2, 2), mode='trilinear')
+        out = F.interpolate(x, scale_factor=(1, 2, 2))
+        out = self.conv(out)
+        out = self.norm(out)
+        out = F.relu(out)
+        return out
+
+
+class DownBlock2d(nn.Module):
+    """
+    Downsampling block for use in encoder.
+    """
+
+    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+        super(DownBlock2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+                              padding=padding, groups=groups)
+        self.norm = BatchNorm2d(out_features, affine=True)
+        self.pool = nn.AvgPool2d(kernel_size=(2, 2))
+
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.norm(out)
+        out = F.relu(out)
+        out = self.pool(out)
+        return out
+
+
+class DownBlock3d(nn.Module):
+    """
+    Downsampling block for use in encoder.
+    """
+
+    def __init__(self, in_features, out_features, kernel_size=3, padding=1, groups=1):
+        super(DownBlock3d, self).__init__()
+        '''
+        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+                              padding=padding, groups=groups, stride=(1, 2, 2))
+        '''
+        self.conv = nn.Conv3d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size,
+                              padding=padding, groups=groups)
+        self.norm = BatchNorm3d(out_features, affine=True)
+        self.pool = nn.AvgPool3d(kernel_size=(1, 2, 2))
+
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.norm(out)
+        out = F.relu(out)
+        out = self.pool(out)
+        return out
+
+
+class SameBlock2d(nn.Module):
+    """
+    Simple block, preserve spatial resolution.
+    """
+
+    def __init__(self, in_features, out_features, groups=1, kernel_size=3, padding=1, lrelu=False):
+        super(SameBlock2d, self).__init__()
+        self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features,
+                              kernel_size=kernel_size, padding=padding, groups=groups)
+        self.norm = BatchNorm2d(out_features, affine=True)
+        if lrelu:
+            self.ac = nn.LeakyReLU()
+        else:
+            self.ac = nn.ReLU()
+
+    def forward(self, x):
+        out = self.conv(x)
+        out = self.norm(out)
+        out = self.ac(out)
+        return out
+
+
+class Encoder(nn.Module):
+    """
+    Hourglass Encoder
+    """
+
+    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+        super(Encoder, self).__init__()
+
+        down_blocks = []
+        for i in range(num_blocks):
+            down_blocks.append(DownBlock3d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+                                           min(max_features, block_expansion * (2 ** (i + 1))),
+                                           kernel_size=3, padding=1))
+        self.down_blocks = nn.ModuleList(down_blocks)
+
+    def forward(self, x):
+        outs = [x]
+        for down_block in self.down_blocks:
+            outs.append(down_block(outs[-1]))
+        return outs
+
+
+class Decoder(nn.Module):
+    """
+    Hourglass Decoder
+    """
+
+    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+        super(Decoder, self).__init__()
+
+        up_blocks = []
+
+        for i in range(num_blocks)[::-1]:
+            in_filters = (1 if i == num_blocks - 1 else 2) * min(max_features, block_expansion * (2 ** (i + 1)))
+            out_filters = min(max_features, block_expansion * (2 ** i))
+            up_blocks.append(UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
+
+        self.up_blocks = nn.ModuleList(up_blocks)
+        # self.out_filters = block_expansion
+        self.out_filters = block_expansion + in_features
+
+        self.conv = nn.Conv3d(in_channels=self.out_filters, out_channels=self.out_filters, kernel_size=3, padding=1)
+        self.norm = BatchNorm3d(self.out_filters, affine=True)
+
+    def forward(self, x):
+        out = x.pop()
+        # for up_block in self.up_blocks[:-1]:
+        for up_block in self.up_blocks:
+            out = up_block(out)
+            skip = x.pop()
+            out = torch.cat([out, skip], dim=1)
+        # out = self.up_blocks[-1](out)
+        out = self.conv(out)
+        out = self.norm(out)
+        out = F.relu(out)
+        return out
+
+
+class Hourglass(nn.Module):
+    """
+    Hourglass architecture.
+    """
+
+    def __init__(self, block_expansion, in_features, num_blocks=3, max_features=256):
+        super(Hourglass, self).__init__()
+        self.encoder = Encoder(block_expansion, in_features, num_blocks, max_features)
+        self.decoder = Decoder(block_expansion, in_features, num_blocks, max_features)
+        self.out_filters = self.decoder.out_filters
+
+    def forward(self, x):
+        return self.decoder(self.encoder(x))
+
+
+class KPHourglass(nn.Module):
+    """
+    Hourglass architecture.
+    """ 
+
+    def __init__(self, block_expansion, in_features, reshape_features, reshape_depth, num_blocks=3, max_features=256):
+        super(KPHourglass, self).__init__()
+        
+        self.down_blocks = nn.Sequential()
+        for i in range(num_blocks):
+            self.down_blocks.add_module('down'+ str(i), DownBlock2d(in_features if i == 0 else min(max_features, block_expansion * (2 ** i)),
+                                                                   min(max_features, block_expansion * (2 ** (i + 1))),
+                                                                   kernel_size=3, padding=1))
+
+        in_filters = min(max_features, block_expansion * (2 ** num_blocks))
+        self.conv = nn.Conv2d(in_channels=in_filters, out_channels=reshape_features, kernel_size=1)
+
+        self.up_blocks = nn.Sequential()
+        for i in range(num_blocks):
+            in_filters = min(max_features, block_expansion * (2 ** (num_blocks - i)))
+            out_filters = min(max_features, block_expansion * (2 ** (num_blocks - i - 1)))
+            self.up_blocks.add_module('up'+ str(i), UpBlock3d(in_filters, out_filters, kernel_size=3, padding=1))
+
+        self.reshape_depth = reshape_depth
+        self.out_filters = out_filters
+
+    def forward(self, x):
+        out = self.down_blocks(x)
+        out = self.conv(out)
+        bs, c, h, w = out.shape
+        out = out.view(bs, c//self.reshape_depth, self.reshape_depth, h, w)
+        out = self.up_blocks(out)
+
+        return out
+        
+
+
+class AntiAliasInterpolation2d(nn.Module):
+    """
+    Band-limited downsampling, for better preservation of the input signal.
+    """
+    def __init__(self, channels, scale):
+        super(AntiAliasInterpolation2d, self).__init__()
+        sigma = (1 / scale - 1) / 2
+        kernel_size = 2 * round(sigma * 4) + 1
+        self.ka = kernel_size // 2
+        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
+
+        kernel_size = [kernel_size, kernel_size]
+        sigma = [sigma, sigma]
+        # The gaussian kernel is the product of the
+        # gaussian function of each dimension.
+        kernel = 1
+        meshgrids = torch.meshgrid(
+            [
+                torch.arange(size, dtype=torch.float32)
+                for size in kernel_size
+                ]
+        )
+        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
+            mean = (size - 1) / 2
+            kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
+
+        # Make sure sum of values in gaussian kernel equals 1.
+        kernel = kernel / torch.sum(kernel)
+        # Reshape to depthwise convolutional weight
+        kernel = kernel.view(1, 1, *kernel.size())
+        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
+
+        self.register_buffer('weight', kernel)
+        self.groups = channels
+        self.scale = scale
+        inv_scale = 1 / scale
+        self.int_inv_scale = int(inv_scale)
+
+    def forward(self, input):
+        if self.scale == 1.0:
+            return input
+
+        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
+        out = F.conv2d(out, weight=self.weight, groups=self.groups)
+        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
+
+        return out
+
+
+class SPADE(nn.Module):
+    def __init__(self, norm_nc, label_nc):
+        super().__init__()
+
+        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
+        nhidden = 128
+
+        self.mlp_shared = nn.Sequential(
+            nn.Conv2d(label_nc, nhidden, kernel_size=3, padding=1),
+            nn.ReLU())
+        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=3, padding=1)
+
+    def forward(self, x, segmap):
+        normalized = self.param_free_norm(x)
+        segmap = F.interpolate(segmap, size=x.size()[2:], mode='nearest')
+        actv = self.mlp_shared(segmap)
+        gamma = self.mlp_gamma(actv)
+        beta = self.mlp_beta(actv)
+        out = normalized * (1 + gamma) + beta
+        return out
+    
+
+class SPADEResnetBlock(nn.Module):
+    def __init__(self, fin, fout, norm_G, label_nc, use_se=False, dilation=1):
+        super().__init__()
+        # Attributes
+        self.learned_shortcut = (fin != fout)
+        fmiddle = min(fin, fout)
+        self.use_se = use_se
+        # create conv layers
+        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=dilation, dilation=dilation)
+        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=dilation, dilation=dilation)
+        if self.learned_shortcut:
+            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
+        # apply spectral norm if specified
+        if 'spectral' in norm_G:
+            self.conv_0 = spectral_norm(self.conv_0)
+            self.conv_1 = spectral_norm(self.conv_1)
+            if self.learned_shortcut:
+                self.conv_s = spectral_norm(self.conv_s)
+        # define normalization layers
+        self.norm_0 = SPADE(fin, label_nc)
+        self.norm_1 = SPADE(fmiddle, label_nc)
+        if self.learned_shortcut:
+            self.norm_s = SPADE(fin, label_nc)
+
+    def forward(self, x, seg1):
+        x_s = self.shortcut(x, seg1)
+        dx = self.conv_0(self.actvn(self.norm_0(x, seg1)))
+        dx = self.conv_1(self.actvn(self.norm_1(dx, seg1)))
+        out = x_s + dx
+        return out
+
+    def shortcut(self, x, seg1):
+        if self.learned_shortcut:
+            x_s = self.conv_s(self.norm_s(x, seg1))
+        else:
+            x_s = x
+        return x_s
+
+    def actvn(self, x):
+        return F.leaky_relu(x, 2e-1)
+
+class audio2image(nn.Module):
+    def __init__(self, generator, kp_extractor, he_estimator_video, he_estimator_audio, train_params):
+        super().__init__()
+        # Attributes
+        self.generator = generator
+        self.kp_extractor = kp_extractor
+        self.he_estimator_video = he_estimator_video
+        self.he_estimator_audio = he_estimator_audio
+        self.train_params = train_params
+
+    def headpose_pred_to_degree(self, pred):
+        device = pred.device
+        idx_tensor = [idx for idx in range(66)]
+        idx_tensor = torch.FloatTensor(idx_tensor).to(device)
+        pred = F.softmax(pred)
+        degree = torch.sum(pred*idx_tensor, 1) * 3 - 99
+
+        return degree
+    
+    def get_rotation_matrix(self, yaw, pitch, roll):
+        yaw = yaw / 180 * 3.14
+        pitch = pitch / 180 * 3.14
+        roll = roll / 180 * 3.14
+
+        roll = roll.unsqueeze(1)
+        pitch = pitch.unsqueeze(1)
+        yaw = yaw.unsqueeze(1)
+
+        roll_mat = torch.cat([torch.ones_like(roll), torch.zeros_like(roll), torch.zeros_like(roll), 
+                          torch.zeros_like(roll), torch.cos(roll), -torch.sin(roll),
+                          torch.zeros_like(roll), torch.sin(roll), torch.cos(roll)], dim=1)
+        roll_mat = roll_mat.view(roll_mat.shape[0], 3, 3)
+
+        pitch_mat = torch.cat([torch.cos(pitch), torch.zeros_like(pitch), torch.sin(pitch), 
+                           torch.zeros_like(pitch), torch.ones_like(pitch), torch.zeros_like(pitch),
+                           -torch.sin(pitch), torch.zeros_like(pitch), torch.cos(pitch)], dim=1)
+        pitch_mat = pitch_mat.view(pitch_mat.shape[0], 3, 3)
+
+        yaw_mat = torch.cat([torch.cos(yaw), -torch.sin(yaw), torch.zeros_like(yaw),  
+                         torch.sin(yaw), torch.cos(yaw), torch.zeros_like(yaw),
+                         torch.zeros_like(yaw), torch.zeros_like(yaw), torch.ones_like(yaw)], dim=1)
+        yaw_mat = yaw_mat.view(yaw_mat.shape[0], 3, 3)
+
+        rot_mat = torch.einsum('bij,bjk,bkm->bim', roll_mat, pitch_mat, yaw_mat)
+
+        return rot_mat
+
+    def keypoint_transformation(self, kp_canonical, he):
+        kp = kp_canonical['value']    # (bs, k, 3)
+        yaw, pitch, roll = he['yaw'], he['pitch'], he['roll']
+        t, exp = he['t'], he['exp']
+    
+        yaw = self.headpose_pred_to_degree(yaw)
+        pitch = self.headpose_pred_to_degree(pitch)
+        roll = self.headpose_pred_to_degree(roll)
+
+        rot_mat = self.get_rotation_matrix(yaw, pitch, roll)    # (bs, 3, 3)
+    
+        # keypoint rotation
+        kp_rotated = torch.einsum('bmp,bkp->bkm', rot_mat, kp)
+
+    
+
+        # keypoint translation
+        t = t.unsqueeze_(1).repeat(1, kp.shape[1], 1)
+        kp_t = kp_rotated + t
+
+        # add expression deviation 
+        exp = exp.view(exp.shape[0], -1, 3)
+        kp_transformed = kp_t + exp
+
+        return {'value': kp_transformed}
+
+    def forward(self, source_image, target_audio):
+        pose_source = self.he_estimator_video(source_image)
+        pose_generated = self.he_estimator_audio(target_audio)
+        kp_canonical = self.kp_extractor(source_image)
+        kp_source = self.keypoint_transformation(kp_canonical, pose_source)
+        kp_transformed_generated = self.keypoint_transformation(kp_canonical, pose_generated)
+        generated = self.generator(source_image, kp_source=kp_source, kp_driving=kp_transformed_generated)
+        return generated
\ No newline at end of file
diff --git a/src/facerender/sync_batchnorm/__init__.py b/src/facerender/sync_batchnorm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..48871cdcdc882c903501ecc6d70fcb1b50bd7e9f
--- /dev/null
+++ b/src/facerender/sync_batchnorm/__init__.py
@@ -0,0 +1,12 @@
+# -*- coding: utf-8 -*-
+# File   : __init__.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+# 
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
+from .replicate import DataParallelWithCallback, patch_replication_callback
diff --git a/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3a3b4730a5580b232849b5349c41bffc5ed5ad9f
Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/__init__.cpython-38.pyc differ
diff --git a/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..437042c515e69a06dd426f60b2b91c230bebe5da
Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/batchnorm.cpython-38.pyc differ
diff --git a/src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..3ac35043eddd75928477fe9a5232951fe770745e
Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/comm.cpython-38.pyc differ
diff --git a/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..258b14a0548aa56beb1c6ca817b53129e97377f7
Binary files /dev/null and b/src/facerender/sync_batchnorm/__pycache__/replicate.cpython-38.pyc differ
diff --git a/src/facerender/sync_batchnorm/batchnorm.py b/src/facerender/sync_batchnorm/batchnorm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4cc2ccd2f0c904cbe433fb6136f443f0fa86fa6
--- /dev/null
+++ b/src/facerender/sync_batchnorm/batchnorm.py
@@ -0,0 +1,315 @@
+# -*- coding: utf-8 -*-
+# File   : batchnorm.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+# 
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import collections
+
+import torch
+import torch.nn.functional as F
+
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast
+
+from .comm import SyncMaster
+
+__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']
+
+
+def _sum_ft(tensor):
+    """sum over the first and last dimention"""
+    return tensor.sum(dim=0).sum(dim=-1)
+
+
+def _unsqueeze_ft(tensor):
+    """add new dementions at the front and the tail"""
+    return tensor.unsqueeze(0).unsqueeze(-1)
+
+
+_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
+_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])
+
+
+class _SynchronizedBatchNorm(_BatchNorm):
+    def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True):
+        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)
+
+        self._sync_master = SyncMaster(self._data_parallel_master)
+
+        self._is_parallel = False
+        self._parallel_id = None
+        self._slave_pipe = None
+
+    def forward(self, input):
+        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
+        if not (self._is_parallel and self.training):
+            return F.batch_norm(
+                input, self.running_mean, self.running_var, self.weight, self.bias,
+                self.training, self.momentum, self.eps)
+
+        # Resize the input to (B, C, -1).
+        input_shape = input.size()
+        input = input.view(input.size(0), self.num_features, -1)
+
+        # Compute the sum and square-sum.
+        sum_size = input.size(0) * input.size(2)
+        input_sum = _sum_ft(input)
+        input_ssum = _sum_ft(input ** 2)
+
+        # Reduce-and-broadcast the statistics.
+        if self._parallel_id == 0:
+            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
+        else:
+            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))
+
+        # Compute the output.
+        if self.affine:
+            # MJY:: Fuse the multiplication for speed.
+            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
+        else:
+            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)
+
+        # Reshape it.
+        return output.view(input_shape)
+
+    def __data_parallel_replicate__(self, ctx, copy_id):
+        self._is_parallel = True
+        self._parallel_id = copy_id
+
+        # parallel_id == 0 means master device.
+        if self._parallel_id == 0:
+            ctx.sync_master = self._sync_master
+        else:
+            self._slave_pipe = ctx.sync_master.register_slave(copy_id)
+
+    def _data_parallel_master(self, intermediates):
+        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
+
+        # Always using same "device order" makes the ReduceAdd operation faster.
+        # Thanks to:: Tete Xiao (http://tetexiao.com/)
+        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())
+
+        to_reduce = [i[1][:2] for i in intermediates]
+        to_reduce = [j for i in to_reduce for j in i]  # flatten
+        target_gpus = [i[1].sum.get_device() for i in intermediates]
+
+        sum_size = sum([i[1].sum_size for i in intermediates])
+        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)
+        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)
+
+        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)
+
+        outputs = []
+        for i, rec in enumerate(intermediates):
+            outputs.append((rec[0], _MasterMessage(*broadcasted[i*2:i*2+2])))
+
+        return outputs
+
+    def _compute_mean_std(self, sum_, ssum, size):
+        """Compute the mean and standard-deviation with sum and square-sum. This method
+        also maintains the moving average on the master device."""
+        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
+        mean = sum_ / size
+        sumvar = ssum - sum_ * mean
+        unbias_var = sumvar / (size - 1)
+        bias_var = sumvar / size
+
+        self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean.data
+        self.running_var = (1 - self.momentum) * self.running_var + self.momentum * unbias_var.data
+
+        return mean, bias_var.clamp(self.eps) ** -0.5
+
+
+class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
+    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
+    mini-batch.
+
+    .. math::
+
+        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+    This module differs from the built-in PyTorch BatchNorm1d as the mean and
+    standard-deviation are reduced across all devices during training.
+
+    For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalize the tensor on each device using
+    the statistics only on that device, which accelerated the computation and
+    is also easy to implement, but the statistics might be inaccurate.
+    Instead, in this synchronized version, the statistics will be computed
+    over all training samples distributed on multiple devices.
+    
+    Note that, for one-GPU or CPU-only case, this module behaves exactly same
+    as the built-in PyTorch implementation.
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and gamma and beta are learnable parameter vectors
+    of size C (where C is the input size).
+
+    During training, this layer keeps a running estimate of its computed mean
+    and variance. The running sum is kept with a default momentum of 0.1.
+
+    During evaluation, this running mean/variance is used for normalization.
+
+    Because the BatchNorm is done over the `C` dimension, computing statistics
+    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm
+
+    Args:
+        num_features: num_features from an expected input of size
+            `batch_size x num_features [x width]`
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Default: 0.1
+        affine: a boolean value that when set to ``True``, gives the layer learnable
+            affine parameters. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C)` or :math:`(N, C, L)`
+        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = SynchronizedBatchNorm1d(100)
+        >>> # Without Learnable Parameters
+        >>> m = SynchronizedBatchNorm1d(100, affine=False)
+        >>> input = torch.autograd.Variable(torch.randn(20, 100))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 2 and input.dim() != 3:
+            raise ValueError('expected 2D or 3D input (got {}D input)'
+                             .format(input.dim()))
+        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
+    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
+    of 3d inputs
+
+    .. math::
+
+        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+    This module differs from the built-in PyTorch BatchNorm2d as the mean and
+    standard-deviation are reduced across all devices during training.
+
+    For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalize the tensor on each device using
+    the statistics only on that device, which accelerated the computation and
+    is also easy to implement, but the statistics might be inaccurate.
+    Instead, in this synchronized version, the statistics will be computed
+    over all training samples distributed on multiple devices.
+    
+    Note that, for one-GPU or CPU-only case, this module behaves exactly same
+    as the built-in PyTorch implementation.
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and gamma and beta are learnable parameter vectors
+    of size C (where C is the input size).
+
+    During training, this layer keeps a running estimate of its computed mean
+    and variance. The running sum is kept with a default momentum of 0.1.
+
+    During evaluation, this running mean/variance is used for normalization.
+
+    Because the BatchNorm is done over the `C` dimension, computing statistics
+    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm
+
+    Args:
+        num_features: num_features from an expected input of
+            size batch_size x num_features x height x width
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Default: 0.1
+        affine: a boolean value that when set to ``True``, gives the layer learnable
+            affine parameters. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C, H, W)`
+        - Output: :math:`(N, C, H, W)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = SynchronizedBatchNorm2d(100)
+        >>> # Without Learnable Parameters
+        >>> m = SynchronizedBatchNorm2d(100, affine=False)
+        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 4:
+            raise ValueError('expected 4D input (got {}D input)'
+                             .format(input.dim()))
+        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)
+
+
+class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
+    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
+    of 4d inputs
+
+    .. math::
+
+        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta
+
+    This module differs from the built-in PyTorch BatchNorm3d as the mean and
+    standard-deviation are reduced across all devices during training.
+
+    For example, when one uses `nn.DataParallel` to wrap the network during
+    training, PyTorch's implementation normalize the tensor on each device using
+    the statistics only on that device, which accelerated the computation and
+    is also easy to implement, but the statistics might be inaccurate.
+    Instead, in this synchronized version, the statistics will be computed
+    over all training samples distributed on multiple devices.
+    
+    Note that, for one-GPU or CPU-only case, this module behaves exactly same
+    as the built-in PyTorch implementation.
+
+    The mean and standard-deviation are calculated per-dimension over
+    the mini-batches and gamma and beta are learnable parameter vectors
+    of size C (where C is the input size).
+
+    During training, this layer keeps a running estimate of its computed mean
+    and variance. The running sum is kept with a default momentum of 0.1.
+
+    During evaluation, this running mean/variance is used for normalization.
+
+    Because the BatchNorm is done over the `C` dimension, computing statistics
+    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
+    or Spatio-temporal BatchNorm
+
+    Args:
+        num_features: num_features from an expected input of
+            size batch_size x num_features x depth x height x width
+        eps: a value added to the denominator for numerical stability.
+            Default: 1e-5
+        momentum: the value used for the running_mean and running_var
+            computation. Default: 0.1
+        affine: a boolean value that when set to ``True``, gives the layer learnable
+            affine parameters. Default: ``True``
+
+    Shape:
+        - Input: :math:`(N, C, D, H, W)`
+        - Output: :math:`(N, C, D, H, W)` (same shape as input)
+
+    Examples:
+        >>> # With Learnable Parameters
+        >>> m = SynchronizedBatchNorm3d(100)
+        >>> # Without Learnable Parameters
+        >>> m = SynchronizedBatchNorm3d(100, affine=False)
+        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
+        >>> output = m(input)
+    """
+
+    def _check_input_dim(self, input):
+        if input.dim() != 5:
+            raise ValueError('expected 5D input (got {}D input)'
+                             .format(input.dim()))
+        super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
diff --git a/src/facerender/sync_batchnorm/comm.py b/src/facerender/sync_batchnorm/comm.py
new file mode 100644
index 0000000000000000000000000000000000000000..b66ec4aea213edf4330beda0a8c8b93d6db77a60
--- /dev/null
+++ b/src/facerender/sync_batchnorm/comm.py
@@ -0,0 +1,137 @@
+# -*- coding: utf-8 -*-
+# File   : comm.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+# 
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import queue
+import collections
+import threading
+
+__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']
+
+
+class FutureResult(object):
+    """A thread-safe future implementation. Used only as one-to-one pipe."""
+
+    def __init__(self):
+        self._result = None
+        self._lock = threading.Lock()
+        self._cond = threading.Condition(self._lock)
+
+    def put(self, result):
+        with self._lock:
+            assert self._result is None, 'Previous result has\'t been fetched.'
+            self._result = result
+            self._cond.notify()
+
+    def get(self):
+        with self._lock:
+            if self._result is None:
+                self._cond.wait()
+
+            res = self._result
+            self._result = None
+            return res
+
+
+_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
+_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])
+
+
+class SlavePipe(_SlavePipeBase):
+    """Pipe for master-slave communication."""
+
+    def run_slave(self, msg):
+        self.queue.put((self.identifier, msg))
+        ret = self.result.get()
+        self.queue.put(True)
+        return ret
+
+
+class SyncMaster(object):
+    """An abstract `SyncMaster` object.
+
+    - During the replication, as the data parallel will trigger an callback of each module, all slave devices should
+    call `register(id)` and obtain an `SlavePipe` to communicate with the master.
+    - During the forward pass, master device invokes `run_master`, all messages from slave devices will be collected,
+    and passed to a registered callback.
+    - After receiving the messages, the master device should gather the information and determine to message passed
+    back to each slave devices.
+    """
+
+    def __init__(self, master_callback):
+        """
+
+        Args:
+            master_callback: a callback to be invoked after having collected messages from slave devices.
+        """
+        self._master_callback = master_callback
+        self._queue = queue.Queue()
+        self._registry = collections.OrderedDict()
+        self._activated = False
+
+    def __getstate__(self):
+        return {'master_callback': self._master_callback}
+
+    def __setstate__(self, state):
+        self.__init__(state['master_callback'])
+
+    def register_slave(self, identifier):
+        """
+        Register an slave device.
+
+        Args:
+            identifier: an identifier, usually is the device id.
+
+        Returns: a `SlavePipe` object which can be used to communicate with the master device.
+
+        """
+        if self._activated:
+            assert self._queue.empty(), 'Queue is not clean before next initialization.'
+            self._activated = False
+            self._registry.clear()
+        future = FutureResult()
+        self._registry[identifier] = _MasterRegistry(future)
+        return SlavePipe(identifier, self._queue, future)
+
+    def run_master(self, master_msg):
+        """
+        Main entry for the master device in each forward pass.
+        The messages were first collected from each devices (including the master device), and then
+        an callback will be invoked to compute the message to be sent back to each devices
+        (including the master device).
+
+        Args:
+            master_msg: the message that the master want to send to itself. This will be placed as the first
+            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.
+
+        Returns: the message to be sent back to the master device.
+
+        """
+        self._activated = True
+
+        intermediates = [(0, master_msg)]
+        for i in range(self.nr_slaves):
+            intermediates.append(self._queue.get())
+
+        results = self._master_callback(intermediates)
+        assert results[0][0] == 0, 'The first result should belongs to the master.'
+
+        for i, res in results:
+            if i == 0:
+                continue
+            self._registry[i].result.put(res)
+
+        for i in range(self.nr_slaves):
+            assert self._queue.get() is True
+
+        return results[0][1]
+
+    @property
+    def nr_slaves(self):
+        return len(self._registry)
diff --git a/src/facerender/sync_batchnorm/replicate.py b/src/facerender/sync_batchnorm/replicate.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b97380d9c5fbe75c4b3583d3668ccd6a2848699
--- /dev/null
+++ b/src/facerender/sync_batchnorm/replicate.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+# File   : replicate.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+# 
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import functools
+
+from torch.nn.parallel.data_parallel import DataParallel
+
+__all__ = [
+    'CallbackContext',
+    'execute_replication_callbacks',
+    'DataParallelWithCallback',
+    'patch_replication_callback'
+]
+
+
+class CallbackContext(object):
+    pass
+
+
+def execute_replication_callbacks(modules):
+    """
+    Execute an replication callback `__data_parallel_replicate__` on each module created by original replication.
+
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Note that, as all modules are isomorphism, we assign each sub-module with a context
+    (shared among multiple copies of this module on different devices).
+    Through this context, different copies can share some information.
+
+    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
+    of any slave copies.
+    """
+    master_copy = modules[0]
+    nr_modules = len(list(master_copy.modules()))
+    ctxs = [CallbackContext() for _ in range(nr_modules)]
+
+    for i, module in enumerate(modules):
+        for j, m in enumerate(module.modules()):
+            if hasattr(m, '__data_parallel_replicate__'):
+                m.__data_parallel_replicate__(ctxs[j], i)
+
+
+class DataParallelWithCallback(DataParallel):
+    """
+    Data Parallel with a replication callback.
+
+    An replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
+    original `replicate` function.
+    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`
+
+    Examples:
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+        # sync_bn.__data_parallel_replicate__ will be invoked.
+    """
+
+    def replicate(self, module, device_ids):
+        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
+
+def patch_replication_callback(data_parallel):
+    """
+    Monkey-patch an existing `DataParallel` object. Add the replication callback.
+    Useful when you have customized `DataParallel` implementation.
+
+    Examples:
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
+        > patch_replication_callback(sync_bn)
+        # this is equivalent to
+        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
+        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
+    """
+
+    assert isinstance(data_parallel, DataParallel)
+
+    old_replicate = data_parallel.replicate
+
+    @functools.wraps(old_replicate)
+    def new_replicate(module, device_ids):
+        modules = old_replicate(module, device_ids)
+        execute_replication_callbacks(modules)
+        return modules
+
+    data_parallel.replicate = new_replicate
diff --git a/src/facerender/sync_batchnorm/unittest.py b/src/facerender/sync_batchnorm/unittest.py
new file mode 100644
index 0000000000000000000000000000000000000000..9716d035495097fb086ec050ab0bc9b76b9d28a0
--- /dev/null
+++ b/src/facerender/sync_batchnorm/unittest.py
@@ -0,0 +1,29 @@
+# -*- coding: utf-8 -*-
+# File   : unittest.py
+# Author : Jiayuan Mao
+# Email  : maojiayuan@gmail.com
+# Date   : 27/01/2018
+# 
+# This file is part of Synchronized-BatchNorm-PyTorch.
+# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
+# Distributed under MIT License.
+
+import unittest
+
+import numpy as np
+from torch.autograd import Variable
+
+
+def as_numpy(v):
+    if isinstance(v, Variable):
+        v = v.data
+    return v.cpu().numpy()
+
+
+class TorchTestCase(unittest.TestCase):
+    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
+        npa, npb = as_numpy(a), as_numpy(b)
+        self.assertTrue(
+                np.allclose(npa, npb, atol=atol),
+                'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
+        )
diff --git a/src/generate_batch.py b/src/generate_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..41c7b1c92cefcaf83215159977259bd9e60f5f61
--- /dev/null
+++ b/src/generate_batch.py
@@ -0,0 +1,115 @@
+import os
+
+from tqdm import tqdm
+import torch
+import numpy as np
+import random
+import scipy.io as scio
+import src.utils.audio as audio
+
+def crop_pad_audio(wav, audio_length):
+    if len(wav) > audio_length:
+        wav = wav[:audio_length]
+    elif len(wav) < audio_length:
+        wav = np.pad(wav, [0, audio_length - len(wav)], mode='constant', constant_values=0)
+    return wav
+
+def parse_audio_length(audio_length, sr, fps):
+    bit_per_frames = sr / fps
+
+    num_frames = int(audio_length / bit_per_frames)
+    audio_length = int(num_frames * bit_per_frames)
+
+    return audio_length, num_frames
+
+def generate_blink_seq(num_frames):
+    ratio = np.zeros((num_frames,1))
+    frame_id = 0
+    while frame_id in range(num_frames):
+        start = 80
+        if frame_id+start+9<=num_frames - 1:
+            ratio[frame_id+start:frame_id+start+9, 0] = [0.5,0.6,0.7,0.9,1, 0.9, 0.7,0.6,0.5]
+            frame_id = frame_id+start+9
+        else:
+            break
+    return ratio 
+
+def generate_blink_seq_randomly(num_frames):
+    ratio = np.zeros((num_frames,1))
+    if num_frames<=20:
+        return ratio
+    frame_id = 0
+    while frame_id in range(num_frames):
+        start = random.choice(range(min(10,num_frames), min(int(num_frames/2), 70))) 
+        if frame_id+start+5<=num_frames - 1:
+            ratio[frame_id+start:frame_id+start+5, 0] = [0.5, 0.9, 1.0, 0.9, 0.5]
+            frame_id = frame_id+start+5
+        else:
+            break
+    return ratio
+
+def get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=False):
+
+    syncnet_mel_step_size = 16
+    fps = 25
+
+    pic_name = os.path.splitext(os.path.split(first_coeff_path)[-1])[0]
+    audio_name = os.path.splitext(os.path.split(audio_path)[-1])[0]
+
+    wav = audio.load_wav(audio_path, 16000) 
+    wav_length, num_frames = parse_audio_length(len(wav), 16000, 25)
+    wav = crop_pad_audio(wav, wav_length)
+    orig_mel = audio.melspectrogram(wav).T
+    spec = orig_mel.copy()         # nframes 80
+    indiv_mels = []
+
+    for i in tqdm(range(num_frames), 'mel:'):
+        start_frame_num = i-2
+        start_idx = int(80. * (start_frame_num / float(fps)))
+        end_idx = start_idx + syncnet_mel_step_size
+        seq = list(range(start_idx, end_idx))
+        seq = [ min(max(item, 0), orig_mel.shape[0]-1) for item in seq ]
+        m = spec[seq, :]
+        indiv_mels.append(m.T)
+    indiv_mels = np.asarray(indiv_mels)         # T 80 16
+
+    ratio = generate_blink_seq_randomly(num_frames)      # T
+    source_semantics_path = first_coeff_path
+    source_semantics_dict = scio.loadmat(source_semantics_path)
+    ref_coeff = source_semantics_dict['coeff_3dmm'][:1,:70]         #1 70
+    ref_coeff = np.repeat(ref_coeff, num_frames, axis=0)
+
+    if ref_eyeblink_coeff_path is not None:
+        ratio[:num_frames] = 0
+        refeyeblink_coeff_dict = scio.loadmat(ref_eyeblink_coeff_path)
+        refeyeblink_coeff = refeyeblink_coeff_dict['coeff_3dmm'][:,:64]
+        refeyeblink_num_frames = refeyeblink_coeff.shape[0]
+        if refeyeblink_num_frames<num_frames:
+            div = num_frames//refeyeblink_num_frames
+            re = num_frames%refeyeblink_num_frames
+            refeyeblink_coeff_list = [refeyeblink_coeff for i in range(div)]
+            refeyeblink_coeff_list.append(refeyeblink_coeff[:re, :64])
+            refeyeblink_coeff = np.concatenate(refeyeblink_coeff_list, axis=0)
+            print(refeyeblink_coeff.shape[0])
+
+        ref_coeff[:, :64] = refeyeblink_coeff[:num_frames, :64] 
+    
+    indiv_mels = torch.FloatTensor(indiv_mels).unsqueeze(1).unsqueeze(0) # bs T 1 80 16
+
+    if still:
+        ratio = torch.FloatTensor(ratio).unsqueeze(0).fill_(0.)                        # bs T
+    else:
+        ratio = torch.FloatTensor(ratio).unsqueeze(0)
+                               # bs T
+    ref_coeff = torch.FloatTensor(ref_coeff).unsqueeze(0)                # bs 1 70
+
+    indiv_mels = indiv_mels.to(device)
+    ratio = ratio.to(device)
+    ref_coeff = ref_coeff.to(device)
+
+    return {'indiv_mels': indiv_mels,  
+            'ref': ref_coeff, 
+            'num_frames': num_frames, 
+            'ratio_gt': ratio,
+            'audio_name': audio_name, 'pic_name': pic_name}
+
diff --git a/src/generate_facerender_batch.py b/src/generate_facerender_batch.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d19fc0ff9275d6e59c36b74e5bb4ec7b4cc63c3
--- /dev/null
+++ b/src/generate_facerender_batch.py
@@ -0,0 +1,134 @@
+import os
+import numpy as np
+from PIL import Image
+from skimage import io, img_as_float32, transform
+import torch
+import scipy.io as scio
+
+def get_facerender_data(coeff_path, pic_path, first_coeff_path, audio_path, 
+                        batch_size, input_yaw_list=None, input_pitch_list=None, input_roll_list=None, 
+                        expression_scale=1.0, still_mode = False, preprocess='crop'):
+
+    semantic_radius = 13
+    video_name = os.path.splitext(os.path.split(coeff_path)[-1])[0]
+    txt_path = os.path.splitext(coeff_path)[0]
+
+    data={}
+
+    img1 = Image.open(pic_path)
+    source_image = np.array(img1)
+    source_image = img_as_float32(source_image)
+    source_image = transform.resize(source_image, (256, 256, 3))
+    source_image = source_image.transpose((2, 0, 1))
+    source_image_ts = torch.FloatTensor(source_image).unsqueeze(0)
+    source_image_ts = source_image_ts.repeat(batch_size, 1, 1, 1)
+    data['source_image'] = source_image_ts
+ 
+    source_semantics_dict = scio.loadmat(first_coeff_path)
+
+    if preprocess.lower() != 'full':
+        source_semantics = source_semantics_dict['coeff_3dmm'][:1,:70]         #1 70
+    else:
+        source_semantics = source_semantics_dict['coeff_3dmm'][:1,:73]         #1 70
+
+    source_semantics_new = transform_semantic_1(source_semantics, semantic_radius)
+    source_semantics_ts = torch.FloatTensor(source_semantics_new).unsqueeze(0)
+    source_semantics_ts = source_semantics_ts.repeat(batch_size, 1, 1)
+    data['source_semantics'] = source_semantics_ts
+
+    # target 
+    generated_dict = scio.loadmat(coeff_path)
+    generated_3dmm = generated_dict['coeff_3dmm']
+    generated_3dmm[:, :64] = generated_3dmm[:, :64] * expression_scale
+
+    if preprocess.lower() == 'full':
+        generated_3dmm = np.concatenate([generated_3dmm, np.repeat(source_semantics[:,70:], generated_3dmm.shape[0], axis=0)], axis=1)
+
+    if still_mode:
+        generated_3dmm[:, 64:] = np.repeat(source_semantics[:, 64:], generated_3dmm.shape[0], axis=0)
+
+    with open(txt_path+'.txt', 'w') as f:
+        for coeff in generated_3dmm:
+            for i in coeff:
+                f.write(str(i)[:7]   + '  '+'\t')
+            f.write('\n')
+
+    target_semantics_list = [] 
+    frame_num = generated_3dmm.shape[0]
+    data['frame_num'] = frame_num
+    for frame_idx in range(frame_num):
+        target_semantics = transform_semantic_target(generated_3dmm, frame_idx, semantic_radius)
+        target_semantics_list.append(target_semantics)
+
+    remainder = frame_num%batch_size
+    if remainder!=0:
+        for _ in range(batch_size-remainder):
+            target_semantics_list.append(target_semantics)
+
+    target_semantics_np = np.array(target_semantics_list)             #frame_num 70 semantic_radius*2+1
+    target_semantics_np = target_semantics_np.reshape(batch_size, -1, target_semantics_np.shape[-2], target_semantics_np.shape[-1])
+    data['target_semantics_list'] = torch.FloatTensor(target_semantics_np)
+    data['video_name'] = video_name
+    data['audio_path'] = audio_path
+    
+    if input_yaw_list is not None:
+        yaw_c_seq = gen_camera_pose(input_yaw_list, frame_num, batch_size)
+        data['yaw_c_seq'] = torch.FloatTensor(yaw_c_seq)
+    if input_pitch_list is not None:
+        pitch_c_seq = gen_camera_pose(input_pitch_list, frame_num, batch_size)
+        data['pitch_c_seq'] = torch.FloatTensor(pitch_c_seq)
+    if input_roll_list is not None:
+        roll_c_seq = gen_camera_pose(input_roll_list, frame_num, batch_size) 
+        data['roll_c_seq'] = torch.FloatTensor(roll_c_seq)
+ 
+    return data
+
+def transform_semantic_1(semantic, semantic_radius):
+    semantic_list =  [semantic for i in range(0, semantic_radius*2+1)]
+    coeff_3dmm = np.concatenate(semantic_list, 0)
+    return coeff_3dmm.transpose(1,0)
+
+def transform_semantic_target(coeff_3dmm, frame_index, semantic_radius):
+    num_frames = coeff_3dmm.shape[0]
+    seq = list(range(frame_index- semantic_radius, frame_index + semantic_radius+1))
+    index = [ min(max(item, 0), num_frames-1) for item in seq ] 
+    coeff_3dmm_g = coeff_3dmm[index, :]
+    return coeff_3dmm_g.transpose(1,0)
+
+def gen_camera_pose(camera_degree_list, frame_num, batch_size):
+
+    new_degree_list = [] 
+    if len(camera_degree_list) == 1:
+        for _ in range(frame_num):
+            new_degree_list.append(camera_degree_list[0]) 
+        remainder = frame_num%batch_size
+        if remainder!=0:
+            for _ in range(batch_size-remainder):
+                new_degree_list.append(new_degree_list[-1])
+        new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) 
+        return new_degree_np
+
+    degree_sum = 0.
+    for i, degree in enumerate(camera_degree_list[1:]):
+        degree_sum += abs(degree-camera_degree_list[i])
+    
+    degree_per_frame = degree_sum/(frame_num-1)
+    for i, degree in enumerate(camera_degree_list[1:]):
+        degree_last = camera_degree_list[i]
+        degree_step = degree_per_frame * abs(degree-degree_last)/(degree-degree_last)
+        new_degree_list =  new_degree_list + list(np.arange(degree_last, degree, degree_step))
+    if len(new_degree_list) > frame_num:
+        new_degree_list = new_degree_list[:frame_num]
+    elif len(new_degree_list) < frame_num:
+        for _ in range(frame_num-len(new_degree_list)):
+            new_degree_list.append(new_degree_list[-1])
+    print(len(new_degree_list))
+    print(frame_num)
+
+    remainder = frame_num%batch_size
+    if remainder!=0:
+        for _ in range(batch_size-remainder):
+            new_degree_list.append(new_degree_list[-1])
+    new_degree_np = np.array(new_degree_list).reshape(batch_size, -1) 
+    return new_degree_np
+    
diff --git a/src/gradio_demo.py b/src/gradio_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4bdb2177f85333930f0ade60b3d6ac65c87f863d
--- /dev/null
+++ b/src/gradio_demo.py
@@ -0,0 +1,139 @@
+import torch, uuid
+import os, sys, shutil
+from src.utils.preprocess import CropAndExtract
+from src.test_audio2coeff import Audio2Coeff  
+from src.facerender.animate import AnimateFromCoeff
+from src.generate_batch import get_data
+from src.generate_facerender_batch import get_facerender_data
+
+from pydub import AudioSegment
+
+def mp3_to_wav(mp3_filename,wav_filename,frame_rate):
+    mp3_file = AudioSegment.from_file(file=mp3_filename)
+    mp3_file.set_frame_rate(frame_rate).export(wav_filename,format="wav")
+
+
+class SadTalker():
+
+    def __init__(self, checkpoint_path='checkpoints', config_path='src/config', lazy_load=False):
+
+        if torch.cuda.is_available() :
+            device = "cuda"
+        else:
+            device = "cpu"
+        
+        self.device = device
+
+        os.environ['TORCH_HOME']= checkpoint_path
+
+        self.checkpoint_path = checkpoint_path
+        self.config_path = config_path
+
+        self.path_of_lm_croper = os.path.join( checkpoint_path, 'shape_predictor_68_face_landmarks.dat')
+        self.path_of_net_recon_model = os.path.join( checkpoint_path, 'epoch_20.pth')
+        self.dir_of_BFM_fitting = os.path.join( checkpoint_path, 'BFM_Fitting')
+        self.wav2lip_checkpoint = os.path.join( checkpoint_path, 'wav2lip.pth')
+
+        self.audio2pose_checkpoint = os.path.join( checkpoint_path, 'auido2pose_00140-model.pth')
+        self.audio2pose_yaml_path = os.path.join( config_path, 'auido2pose.yaml')
+    
+        self.audio2exp_checkpoint = os.path.join( checkpoint_path, 'auido2exp_00300-model.pth')
+        self.audio2exp_yaml_path = os.path.join( config_path, 'auido2exp.yaml')
+
+        self.free_view_checkpoint = os.path.join( checkpoint_path, 'facevid2vid_00189-model.pth.tar')
+
+        self.lazy_load = lazy_load
+
+        if not self.lazy_load:
+            #init model
+            
+            print(self.audio2pose_checkpoint)
+            self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path, 
+                                    self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
+
+            print(self.path_of_lm_croper)
+            self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
+
+    def test(self, source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False, result_dir='./results/'):
+
+        ### crop: only model,
+
+        if self.lazy_load:
+            #init model
+            
+            print(self.audio2pose_checkpoint)
+            self.audio_to_coeff = Audio2Coeff(self.audio2pose_checkpoint, self.audio2pose_yaml_path, 
+                                    self.audio2exp_checkpoint, self.audio2exp_yaml_path, self.wav2lip_checkpoint, self.device)
+        
+            print(self.path_of_lm_croper)
+            self.preprocess_model = CropAndExtract(self.path_of_lm_croper, self.path_of_net_recon_model, self.dir_of_BFM_fitting, self.device)
+
+        if preprocess == 'full': 
+            self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00109-model.pth.tar')
+            self.facerender_yaml_path = os.path.join(self.config_path, 'facerender_still.yaml')
+        else:
+            self.mapping_checkpoint = os.path.join(self.checkpoint_path, 'mapping_00229-model.pth.tar')
+            self.facerender_yaml_path = os.path.join(self.config_path, 'facerender.yaml')
+
+        print(self.free_view_checkpoint)
+        self.animate_from_coeff = AnimateFromCoeff(self.free_view_checkpoint, self.mapping_checkpoint, 
+                                            self.facerender_yaml_path, self.device)
+
+        time_tag = str(uuid.uuid4())
+        save_dir = os.path.join(result_dir, time_tag)
+        os.makedirs(save_dir, exist_ok=True)
+
+        input_dir = os.path.join(save_dir, 'input')
+        os.makedirs(input_dir, exist_ok=True)
+
+        print(source_image)
+        pic_path = os.path.join(input_dir, os.path.basename(source_image)) 
+        shutil.move(source_image, input_dir)
+
+        if os.path.isfile(driven_audio):
+            audio_path = os.path.join(input_dir, os.path.basename(driven_audio))  
+
+            #### mp3 to wav
+            if '.mp3' in audio_path:
+                mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)
+                audio_path = audio_path.replace('.mp3', '.wav')
+            else:
+                shutil.move(driven_audio, input_dir)
+        else:
+            raise AttributeError("error audio")
+
+
+        os.makedirs(save_dir, exist_ok=True)
+        pose_style = 0
+        #crop image and extract 3dmm from image
+        first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+        os.makedirs(first_frame_dir, exist_ok=True)
+        first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(pic_path, first_frame_dir, preprocess)
+        
+        if first_coeff_path is None:
+            raise AttributeError("No face is detected")
+
+        #audio2ceoff
+        batch = get_data(first_coeff_path, audio_path, self.device, ref_eyeblink_coeff_path=None, still=still_mode) # longer audio?
+        coeff_path = self.audio_to_coeff.generate(batch, save_dir, pose_style)
+        #coeff2video
+        batch_size = 2
+        data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode, preprocess=preprocess)
+        return_path = self.animate_from_coeff.generate(data, save_dir,  pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess)
+        video_name = data['video_name']
+        print(f'The generated video is named {video_name} in {save_dir}')
+
+        if self.lazy_load:
+            del self.preprocess_model
+            del self.audio_to_coeff
+            del self.animate_from_coeff
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+            torch.cuda.synchronize()
+            
+        import gc; gc.collect()
+        
+        return return_path
+
+    
\ No newline at end of file
diff --git a/src/test_audio2coeff.py b/src/test_audio2coeff.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2a4bc50f3abb5b69545a0d09c5abbfd5c6ea1d1
--- /dev/null
+++ b/src/test_audio2coeff.py
@@ -0,0 +1,112 @@
+import os 
+import torch
+import numpy as np
+from scipy.io import savemat, loadmat
+from yacs.config import CfgNode as CN
+from scipy.signal import savgol_filter
+
+from src.audio2pose_models.audio2pose import Audio2Pose
+from src.audio2exp_models.networks import SimpleWrapperV2 
+from src.audio2exp_models.audio2exp import Audio2Exp  
+
+def load_cpk(checkpoint_path, model=None, optimizer=None, device="cpu"):
+    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
+    if model is not None:
+        model.load_state_dict(checkpoint['model'])
+    if optimizer is not None:
+        optimizer.load_state_dict(checkpoint['optimizer'])
+
+    return checkpoint['epoch']
+
+class Audio2Coeff():
+
+    def __init__(self, audio2pose_checkpoint, audio2pose_yaml_path, 
+                        audio2exp_checkpoint, audio2exp_yaml_path, 
+                        wav2lip_checkpoint, device):
+        #load config
+        fcfg_pose = open(audio2pose_yaml_path)
+        cfg_pose = CN.load_cfg(fcfg_pose)
+        cfg_pose.freeze()
+        fcfg_exp = open(audio2exp_yaml_path)
+        cfg_exp = CN.load_cfg(fcfg_exp)
+        cfg_exp.freeze()
+
+        # load audio2pose_model
+        self.audio2pose_model = Audio2Pose(cfg_pose, wav2lip_checkpoint, device=device)
+        self.audio2pose_model = self.audio2pose_model.to(device)
+        self.audio2pose_model.eval()
+        for param in self.audio2pose_model.parameters():
+            param.requires_grad = False 
+        try:
+            load_cpk(audio2pose_checkpoint, model=self.audio2pose_model, device=device)
+        except:
+            raise Exception("Failed in loading audio2pose_checkpoint")
+
+        # load audio2exp_model
+        netG = SimpleWrapperV2()
+        netG = netG.to(device)
+        for param in netG.parameters():
+            netG.requires_grad = False
+        netG.eval()
+        try:
+            load_cpk(audio2exp_checkpoint, model=netG, device=device)
+        except:
+            raise Exception("Failed in loading audio2exp_checkpoint")
+        self.audio2exp_model = Audio2Exp(netG, cfg_exp, device=device, prepare_training_loss=False)
+        self.audio2exp_model = self.audio2exp_model.to(device)
+        for param in self.audio2exp_model.parameters():
+            param.requires_grad = False
+        self.audio2exp_model.eval()
+ 
+        self.device = device
+
+    def generate(self, batch, coeff_save_dir, pose_style, ref_pose_coeff_path=None):
+
+        with torch.no_grad():
+            #test
+            results_dict_exp= self.audio2exp_model.test(batch)
+            exp_pred = results_dict_exp['exp_coeff_pred']                         #bs T 64
+
+            #for class_id in  range(1):
+            #class_id = 0#(i+10)%45
+            #class_id = random.randint(0,46)                                   #46 styles can be selected 
+            batch['class'] = torch.LongTensor([pose_style]).to(self.device)
+            results_dict_pose = self.audio2pose_model.test(batch) 
+            pose_pred = results_dict_pose['pose_pred']                        #bs T 6
+
+            pose_len = pose_pred.shape[1]
+            if pose_len<13: 
+                pose_len = int((pose_len-1)/2)*2+1
+                pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), pose_len, 2, axis=1)).to(self.device)
+            else:
+                pose_pred = torch.Tensor(savgol_filter(np.array(pose_pred.cpu()), 13, 2, axis=1)).to(self.device) 
+            
+            coeffs_pred = torch.cat((exp_pred, pose_pred), dim=-1)            #bs T 70
+
+            coeffs_pred_numpy = coeffs_pred[0].clone().detach().cpu().numpy() 
+
+            
+            if ref_pose_coeff_path is not None: 
+                 coeffs_pred_numpy = self.using_refpose(coeffs_pred_numpy, ref_pose_coeff_path)
+        
+            savemat(os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name'])),  
+                    {'coeff_3dmm': coeffs_pred_numpy})
+
+            return os.path.join(coeff_save_dir, '%s##%s.mat'%(batch['pic_name'], batch['audio_name']))
+    
+    def using_refpose(self, coeffs_pred_numpy, ref_pose_coeff_path):
+        num_frames = coeffs_pred_numpy.shape[0]
+        refpose_coeff_dict = loadmat(ref_pose_coeff_path)
+        refpose_coeff = refpose_coeff_dict['coeff_3dmm'][:,64:70]
+        refpose_num_frames = refpose_coeff.shape[0]
+        if refpose_num_frames<num_frames:
+            div = num_frames//refpose_num_frames
+            re = num_frames%refpose_num_frames
+            refpose_coeff_list = [refpose_coeff for i in range(div)]
+            refpose_coeff_list.append(refpose_coeff[:re, :])
+            refpose_coeff = np.concatenate(refpose_coeff_list, axis=0)
+
+        coeffs_pred_numpy[:, 64:70] = refpose_coeff[:num_frames, :] 
+        return coeffs_pred_numpy
+
+
diff --git a/src/utils/__pycache__/audio.cpython-38.pyc b/src/utils/__pycache__/audio.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2b0fc55558968da90bcdeab8399ad2fdda8eb574
Binary files /dev/null and b/src/utils/__pycache__/audio.cpython-38.pyc differ
diff --git a/src/utils/__pycache__/croper.cpython-38.pyc b/src/utils/__pycache__/croper.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..90b52c04644fedc8979634d4a311556a398a2203
Binary files /dev/null and b/src/utils/__pycache__/croper.cpython-38.pyc differ
diff --git a/src/utils/__pycache__/face_enhancer.cpython-38.pyc b/src/utils/__pycache__/face_enhancer.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..eba1639d55a59735c74fbef9e66392fa91691348
Binary files /dev/null and b/src/utils/__pycache__/face_enhancer.cpython-38.pyc differ
diff --git a/src/utils/__pycache__/hparams.cpython-38.pyc b/src/utils/__pycache__/hparams.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..d3706ab887772c57917026174ea656358471c9e2
Binary files /dev/null and b/src/utils/__pycache__/hparams.cpython-38.pyc differ
diff --git a/src/utils/__pycache__/paste_pic.cpython-38.pyc b/src/utils/__pycache__/paste_pic.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..c5ce7d3b1f930a266aa05c1fdb86f8924a59f5a4
Binary files /dev/null and b/src/utils/__pycache__/paste_pic.cpython-38.pyc differ
diff --git a/src/utils/__pycache__/preprocess.cpython-38.pyc b/src/utils/__pycache__/preprocess.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8cef26d2d7f449d3ad4f98267867ea28067e8552
Binary files /dev/null and b/src/utils/__pycache__/preprocess.cpython-38.pyc differ
diff --git a/src/utils/__pycache__/preprocess.cpython-39.pyc b/src/utils/__pycache__/preprocess.cpython-39.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..530e3c27b68daaf86aff084b4736dc5d05634811
Binary files /dev/null and b/src/utils/__pycache__/preprocess.cpython-39.pyc differ
diff --git a/src/utils/__pycache__/videoio.cpython-38.pyc b/src/utils/__pycache__/videoio.cpython-38.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..72a80a4a91de1f4ab07ffdcf54466e1b46e724be
Binary files /dev/null and b/src/utils/__pycache__/videoio.cpython-38.pyc differ
diff --git a/src/utils/audio.py b/src/utils/audio.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3387db15f4528ea07ff2ed7435569956697f633
--- /dev/null
+++ b/src/utils/audio.py
@@ -0,0 +1,136 @@
+import librosa
+import librosa.filters
+import numpy as np
+# import tensorflow as tf
+from scipy import signal
+from scipy.io import wavfile
+from src.utils.hparams import hparams as hp
+
+def load_wav(path, sr):
+    return librosa.core.load(path, sr=sr)[0]
+
+def save_wav(wav, path, sr):
+    wav *= 32767 / max(0.01, np.max(np.abs(wav)))
+    #proposed by @dsmiller
+    wavfile.write(path, sr, wav.astype(np.int16))
+
+def save_wavenet_wav(wav, path, sr):
+    librosa.output.write_wav(path, wav, sr=sr)
+
+def preemphasis(wav, k, preemphasize=True):
+    if preemphasize:
+        return signal.lfilter([1, -k], [1], wav)
+    return wav
+
+def inv_preemphasis(wav, k, inv_preemphasize=True):
+    if inv_preemphasize:
+        return signal.lfilter([1], [1, -k], wav)
+    return wav
+
+def get_hop_size():
+    hop_size = hp.hop_size
+    if hop_size is None:
+        assert hp.frame_shift_ms is not None
+        hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
+    return hop_size
+
+def linearspectrogram(wav):
+    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+    S = _amp_to_db(np.abs(D)) - hp.ref_level_db
+    
+    if hp.signal_normalization:
+        return _normalize(S)
+    return S
+
+def melspectrogram(wav):
+    D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
+    S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
+    
+    if hp.signal_normalization:
+        return _normalize(S)
+    return S
+
+def _lws_processor():
+    import lws
+    return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
+
+def _stft(y):
+    if hp.use_lws:
+        return _lws_processor(hp).stft(y).T
+    else:
+        return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
+
+##########################################################
+#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
+def num_frames(length, fsize, fshift):
+    """Compute number of time frames of spectrogram
+    """
+    pad = (fsize - fshift)
+    if length % fshift == 0:
+        M = (length + pad * 2 - fsize) // fshift + 1
+    else:
+        M = (length + pad * 2 - fsize) // fshift + 2
+    return M
+
+
+def pad_lr(x, fsize, fshift):
+    """Compute left and right padding
+    """
+    M = num_frames(len(x), fsize, fshift)
+    pad = (fsize - fshift)
+    T = len(x) + 2 * pad
+    r = (M - 1) * fshift + fsize - T
+    return pad, pad + r
+##########################################################
+#Librosa correct padding
+def librosa_pad_lr(x, fsize, fshift):
+    return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
+
+# Conversions
+_mel_basis = None
+
+def _linear_to_mel(spectogram):
+    global _mel_basis
+    if _mel_basis is None:
+        _mel_basis = _build_mel_basis()
+    return np.dot(_mel_basis, spectogram)
+
+def _build_mel_basis():
+    assert hp.fmax <= hp.sample_rate // 2
+    return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
+                               fmin=hp.fmin, fmax=hp.fmax)
+
+def _amp_to_db(x):
+    min_level = np.exp(hp.min_level_db / 20 * np.log(10))
+    return 20 * np.log10(np.maximum(min_level, x))
+
+def _db_to_amp(x):
+    return np.power(10.0, (x) * 0.05)
+
+def _normalize(S):
+    if hp.allow_clipping_in_normalization:
+        if hp.symmetric_mels:
+            return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
+                           -hp.max_abs_value, hp.max_abs_value)
+        else:
+            return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
+    
+    assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
+    if hp.symmetric_mels:
+        return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
+    else:
+        return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
+
+def _denormalize(D):
+    if hp.allow_clipping_in_normalization:
+        if hp.symmetric_mels:
+            return (((np.clip(D, -hp.max_abs_value,
+                              hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
+                    + hp.min_level_db)
+        else:
+            return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
+    
+    if hp.symmetric_mels:
+        return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
+    else:
+        return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
diff --git a/src/utils/croper.py b/src/utils/croper.py
new file mode 100644
index 0000000000000000000000000000000000000000..47684264d4d492f6441cb71b8d481d8a53e956a0
--- /dev/null
+++ b/src/utils/croper.py
@@ -0,0 +1,299 @@
+import os
+import cv2
+import time
+import glob
+import argparse
+import scipy
+import numpy as np
+from PIL import Image
+from tqdm import tqdm
+from itertools import cycle
+
+from torch.multiprocessing import Pool, Process, set_start_method
+
+
+"""
+brief: face alignment with FFHQ method (https://github.com/NVlabs/ffhq-dataset)
+author: lzhbrian (https://lzhbrian.me)
+date: 2020.1.5
+note: code is heavily borrowed from 
+    https://github.com/NVlabs/ffhq-dataset
+    http://dlib.net/face_landmark_detection.py.html
+requirements:
+    apt install cmake
+    conda install Pillow numpy scipy
+    pip install dlib
+    # download face landmark model from: 
+    # http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+"""
+
+import numpy as np
+from PIL import Image
+import dlib
+
+
+class Croper:
+    def __init__(self, path_of_lm):
+        # download model from: http://dlib.net/files/shape_predictor_68_face_landmarks.dat.bz2
+        self.predictor = dlib.shape_predictor(path_of_lm)
+
+    def get_landmark(self, img_np):
+        """get landmark with dlib
+        :return: np.array shape=(68, 2)
+        """
+        detector = dlib.get_frontal_face_detector()
+        dets = detector(img_np, 1)
+        #     print("Number of faces detected: {}".format(len(dets)))
+        #     for k, d in enumerate(dets):
+        if len(dets) == 0:
+            return None
+        d = dets[0]
+        # Get the landmarks/parts for the face in box d.
+        shape = self.predictor(img_np, d)
+        #         print("Part 0: {}, Part 1: {} ...".format(shape.part(0), shape.part(1)))
+        t = list(shape.parts())
+        a = []
+        for tt in t:
+            a.append([tt.x, tt.y])
+        lm = np.array(a)
+        # lm is a shape=(68,2) np.array
+        return lm
+
+    def align_face(self, img, lm, output_size=1024):
+        """
+        :param filepath: str
+        :return: PIL Image
+        """
+        lm_chin = lm[0: 17]  # left-right
+        lm_eyebrow_left = lm[17: 22]  # left-right
+        lm_eyebrow_right = lm[22: 27]  # left-right
+        lm_nose = lm[27: 31]  # top-down
+        lm_nostrils = lm[31: 36]  # top-down
+        lm_eye_left = lm[36: 42]  # left-clockwise
+        lm_eye_right = lm[42: 48]  # left-clockwise
+        lm_mouth_outer = lm[48: 60]  # left-clockwise
+        lm_mouth_inner = lm[60: 68]  # left-clockwise
+
+        # Calculate auxiliary vectors.
+        eye_left = np.mean(lm_eye_left, axis=0)
+        eye_right = np.mean(lm_eye_right, axis=0)
+        eye_avg = (eye_left + eye_right) * 0.5
+        eye_to_eye = eye_right - eye_left
+        mouth_left = lm_mouth_outer[0]
+        mouth_right = lm_mouth_outer[6]
+        mouth_avg = (mouth_left + mouth_right) * 0.5
+        eye_to_mouth = mouth_avg - eye_avg
+
+        # Choose oriented crop rectangle.
+        x = eye_to_eye - np.flipud(eye_to_mouth) * [-1, 1]  # Addition of binocular difference and double mouth difference
+        x /= np.hypot(*x)   # hypot函数计算直角三角形的斜边长,用斜边长对三角形两条直边做归一化
+        x *= max(np.hypot(*eye_to_eye) * 2.0, np.hypot(*eye_to_mouth) * 1.8)    # 双眼差和眼嘴差,选较大的作为基准尺度
+        y = np.flipud(x) * [-1, 1]
+        c = eye_avg + eye_to_mouth * 0.1
+        quad = np.stack([c - x - y, c - x + y, c + x + y, c + x - y])   # 定义四边形,以面部基准位置为中心上下左右平移得到四个顶点
+        qsize = np.hypot(*x) * 2    # 定义四边形的大小(边长),为基准尺度的2倍
+
+        # Shrink.
+        # 如果计算出的四边形太大了,就按比例缩小它
+        shrink = int(np.floor(qsize / output_size * 0.5))
+        if shrink > 1:
+            rsize = (int(np.rint(float(img.size[0]) / shrink)), int(np.rint(float(img.size[1]) / shrink)))
+            img = img.resize(rsize, Image.ANTIALIAS)
+            quad /= shrink
+            qsize /= shrink
+        else:
+            rsize = (int(np.rint(float(img.size[0]))), int(np.rint(float(img.size[1]))))
+
+        # Crop.
+        border = max(int(np.rint(qsize * 0.1)), 3)
+        crop = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+                int(np.ceil(max(quad[:, 1]))))
+        crop = (max(crop[0] - border, 0), max(crop[1] - border, 0), min(crop[2] + border, img.size[0]),
+                min(crop[3] + border, img.size[1]))
+        if crop[2] - crop[0] < img.size[0] or crop[3] - crop[1] < img.size[1]:
+            # img = img.crop(crop)
+            quad -= crop[0:2]
+
+        # Pad.
+        pad = (int(np.floor(min(quad[:, 0]))), int(np.floor(min(quad[:, 1]))), int(np.ceil(max(quad[:, 0]))),
+               int(np.ceil(max(quad[:, 1]))))
+        pad = (max(-pad[0] + border, 0), max(-pad[1] + border, 0), max(pad[2] - img.size[0] + border, 0),
+               max(pad[3] - img.size[1] + border, 0))
+        # if enable_padding and max(pad) > border - 4:
+        #     pad = np.maximum(pad, int(np.rint(qsize * 0.3)))
+        #     img = np.pad(np.float32(img), ((pad[1], pad[3]), (pad[0], pad[2]), (0, 0)), 'reflect')
+        #     h, w, _ = img.shape
+        #     y, x, _ = np.ogrid[:h, :w, :1]
+        #     mask = np.maximum(1.0 - np.minimum(np.float32(x) / pad[0], np.float32(w - 1 - x) / pad[2]),
+        #                       1.0 - np.minimum(np.float32(y) / pad[1], np.float32(h - 1 - y) / pad[3]))
+        #     blur = qsize * 0.02
+        #     img += (scipy.ndimage.gaussian_filter(img, [blur, blur, 0]) - img) * np.clip(mask * 3.0 + 1.0, 0.0, 1.0)
+        #     img += (np.median(img, axis=(0, 1)) - img) * np.clip(mask, 0.0, 1.0)
+        #     img = Image.fromarray(np.uint8(np.clip(np.rint(img), 0, 255)), 'RGB')
+        #     quad += pad[:2]
+
+        # Transform.
+        quad = (quad + 0.5).flatten()
+        lx = max(min(quad[0], quad[2]), 0)
+        ly = max(min(quad[1], quad[7]), 0)
+        rx = min(max(quad[4], quad[6]), img.size[0])
+        ry = min(max(quad[3], quad[5]), img.size[0])
+        # img = img.transform((transform_size, transform_size), Image.QUAD, (quad + 0.5).flatten(),
+        #                     Image.BILINEAR)
+        # if output_size < transform_size:
+        #     img = img.resize((output_size, output_size), Image.ANTIALIAS)
+
+        # Save aligned image.
+        return rsize, crop, [lx, ly, rx, ry]
+
+    # def crop(self, img_np_list):
+    #     for _i in range(len(img_np_list)):
+    #         img_np = img_np_list[_i]
+    #         lm = self.get_landmark(img_np)
+    #         if lm is None:
+    #             return None
+    #         crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=512)
+    #         clx, cly, crx, cry = crop
+    #         lx, ly, rx, ry = quad
+    #         lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+        
+    #         _inp = img_np_list[_i]
+    #         _inp = _inp[cly:cry, clx:crx]
+    #         _inp = _inp[ly:ry, lx:rx]
+    #         img_np_list[_i] = _inp
+    #     return img_np_list
+    
+    def crop(self, img_np_list, still=False, xsize=512):    # first frame for all video
+        img_np = img_np_list[0]
+        lm = self.get_landmark(img_np)
+        if lm is None:
+            raise 'can not detect the landmark from source image'
+        rsize, crop, quad = self.align_face(img=Image.fromarray(img_np), lm=lm, output_size=xsize)
+        clx, cly, crx, cry = crop
+        lx, ly, rx, ry = quad
+        lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+        for _i in range(len(img_np_list)):
+            _inp = img_np_list[_i]
+            _inp = cv2.resize(_inp, (rsize[0], rsize[1]))
+            _inp = _inp[cly:cry, clx:crx]
+            # cv2.imwrite('test1.jpg', _inp)
+            if not still:
+                _inp = _inp[ly:ry, lx:rx]
+            # cv2.imwrite('test2.jpg', _inp)
+            img_np_list[_i] = _inp
+        return img_np_list, crop, quad
+
+
+def read_video(filename, uplimit=100):
+    frames = []
+    cap = cv2.VideoCapture(filename)
+    cnt = 0
+    while cap.isOpened():
+        ret, frame = cap.read()
+        if ret:
+            frame = cv2.resize(frame, (512, 512))
+            frames.append(frame)
+        else:
+            break
+        cnt += 1
+        if cnt >= uplimit:
+            break
+    cap.release()
+    assert len(frames) > 0, f'{filename}: video with no frames!'
+    return frames
+
+
+def create_video(video_name, frames, fps=25, video_format='.mp4', resize_ratio=1):
+    # video_name = os.path.dirname(image_folder) + video_format
+    # img_list = glob.glob1(image_folder, 'frame*')
+    # img_list.sort()
+    # frame = cv2.imread(os.path.join(image_folder, img_list[0]))
+    # frame = cv2.resize(frame, (0, 0), fx=resize_ratio, fy=resize_ratio)
+    # height, width, layers = frames[0].shape
+    height, width, layers = 512, 512, 3
+    if video_format == '.mp4':
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+    elif video_format == '.avi':
+        fourcc = cv2.VideoWriter_fourcc(*'XVID')
+    video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
+    for _frame in frames:
+        _frame = cv2.resize(_frame, (height, width), interpolation=cv2.INTER_LINEAR)
+        video.write(_frame)
+
+def create_images(video_name, frames):
+    height, width, layers = 512, 512, 3
+    images_dir = video_name.split('.')[0]
+    os.makedirs(images_dir, exist_ok=True)
+    for i, _frame in enumerate(frames):
+        _frame = cv2.resize(_frame, (height, width), interpolation=cv2.INTER_LINEAR)
+        _frame_path = os.path.join(images_dir, str(i)+'.jpg')
+        cv2.imwrite(_frame_path, _frame)
+
+def run(data):
+    filename, opt, device = data
+    os.environ['CUDA_VISIBLE_DEVICES'] = device
+    croper = Croper()
+
+    frames = read_video(filename, uplimit=opt.uplimit)
+    name = filename.split('/')[-1]  # .split('.')[0]
+    name = os.path.join(opt.output_dir, name)
+
+    frames = croper.crop(frames)
+    if frames is None:
+        print(f'{name}: detect no face. should removed')
+        return
+    # create_video(name, frames)
+    create_images(name, frames)
+
+
+def get_data_path(video_dir):
+    eg_video_files = ['/apdcephfs/share_1290939/quincheng/datasets/HDTF/backup_fps25/WDA_KatieHill_000.mp4']
+    # filenames = list()
+    # VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
+    # VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
+    # extensions = VIDEO_EXTENSIONS
+    # for ext in extensions:
+    #     filenames = sorted(glob.glob(f'{opt.input_dir}/**/*.{ext}'))
+    # print('Total number of videos:', len(filenames))
+    return eg_video_files
+
+
+def get_wra_data_path(video_dir):
+    if opt.option == 'video':
+        videos_path = sorted(glob.glob(f'{video_dir}/*.mp4'))
+    elif opt.option == 'image':
+        videos_path = sorted(glob.glob(f'{video_dir}/*/'))
+    else:
+        raise NotImplementedError
+    print('Example videos: ', videos_path[:2])
+    return videos_path
+
+
+if __name__ == '__main__':
+    set_start_method('spawn')
+    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
+    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
+    parser.add_argument('--device_ids', type=str, default='0,1')
+    parser.add_argument('--workers', type=int, default=8)
+    parser.add_argument('--uplimit', type=int, default=500)
+    parser.add_argument('--option', type=str, default='video')
+
+    root = '/apdcephfs/share_1290939/quincheng/datasets/HDTF'
+    cmd = f'--input_dir {root}/backup_fps25_first20s_sync/ ' \
+          f'--output_dir {root}/crop512_stylegan_firstframe_sync/ ' \
+          '--device_ids 0 ' \
+          '--workers 8 ' \
+          '--option video ' \
+          '--uplimit 500 '
+    opt = parser.parse_args(cmd.split())
+    # filenames = get_data_path(opt.input_dir)
+    filenames = get_wra_data_path(opt.input_dir)
+    os.makedirs(opt.output_dir, exist_ok=True)
+    print(f'Video numbers: {len(filenames)}')
+    pool = Pool(opt.workers)
+    args_list = cycle([opt])
+    device_ids = opt.device_ids.split(",")
+    device_ids = cycle(device_ids)
+    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
+        None
diff --git a/src/utils/face_enhancer.py b/src/utils/face_enhancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fde2b42b7feb8222474be1ec8a324ab0dec19a7
--- /dev/null
+++ b/src/utils/face_enhancer.py
@@ -0,0 +1,96 @@
+import os
+import torch 
+
+from gfpgan import GFPGANer
+
+from tqdm import tqdm
+
+from src.utils.videoio import load_video_to_cv2
+
+import cv2
+
+
+
+def enhancer(images, method='gfpgan', bg_upsampler='realesrgan'):
+    print('face enhancer....')
+    if os.path.isfile(images): # handle video to images
+        
+        images = load_video_to_cv2(images)
+
+    # ------------------------ set up GFPGAN restorer ------------------------
+    if  method == 'gfpgan':
+        arch = 'clean'
+        channel_multiplier = 2
+        model_name = 'GFPGANv1.4'
+        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth'
+    elif method == 'RestoreFormer':
+        arch = 'RestoreFormer'
+        channel_multiplier = 2
+        model_name = 'RestoreFormer'
+        url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth'
+    elif method == 'codeformer': # TODO:
+        arch = 'CodeFormer'
+        channel_multiplier = 2
+        model_name = 'CodeFormer'
+        url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
+    else:
+        raise ValueError(f'Wrong model version {method}.')
+
+
+    # ------------------------ set up background upsampler ------------------------
+    if bg_upsampler == 'realesrgan':
+        if not torch.cuda.is_available():  # CPU
+            import warnings
+            warnings.warn('The unoptimized RealESRGAN is slow on CPU. We do not use it. '
+                          'If you really want to use it, please modify the corresponding codes.')
+            bg_upsampler = None
+        else:
+            from basicsr.archs.rrdbnet_arch import RRDBNet
+            from realesrgan import RealESRGANer
+            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
+            bg_upsampler = RealESRGANer(
+                scale=2,
+                model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
+                model=model,
+                tile=400,
+                tile_pad=10,
+                pre_pad=0,
+                half=True)  # need to set False in CPU mode
+    else:
+        bg_upsampler = None
+
+    # determine model paths
+    model_path = os.path.join('gfpgan/weights', model_name + '.pth')
+    
+    if not os.path.isfile(model_path):
+        model_path = os.path.join('checkpoints', model_name + '.pth')
+    
+    if not os.path.isfile(model_path):
+        # download pre-trained models from url
+        model_path = url
+
+    restorer = GFPGANer(
+        model_path=model_path,
+        upscale=2,
+        arch=arch,
+        channel_multiplier=channel_multiplier,
+        bg_upsampler=bg_upsampler)
+
+    # ------------------------ restore ------------------------
+    restored_img = [] 
+    for idx in tqdm(range(len(images)), 'Face Enhancer:'):
+        
+        img = cv2.cvtColor(images[idx], cv2.COLOR_RGB2BGR)
+        
+        # restore faces and background if necessary
+        cropped_faces, restored_faces, r_img = restorer.enhance(
+            img,
+            has_aligned=False,
+            only_center_face=False,
+            paste_back=True)
+        
+        r_img = cv2.cvtColor(r_img, cv2.COLOR_BGR2RGB)
+        
+        restored_img += [r_img]
+       
+    return restored_img
diff --git a/src/utils/hparams.py b/src/utils/hparams.py
new file mode 100644
index 0000000000000000000000000000000000000000..83c312d767c35b9adc988157243efc02129fdb84
--- /dev/null
+++ b/src/utils/hparams.py
@@ -0,0 +1,160 @@
+from glob import glob
+import os
+
+class HParams:
+	def __init__(self, **kwargs):
+		self.data = {}
+
+		for key, value in kwargs.items():
+			self.data[key] = value
+
+	def __getattr__(self, key):
+		if key not in self.data:
+			raise AttributeError("'HParams' object has no attribute %s" % key)
+		return self.data[key]
+
+	def set_hparam(self, key, value):
+		self.data[key] = value
+
+
+# Default hyperparameters
+hparams = HParams(
+	num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
+	#  network
+	rescale=True,  # Whether to rescale audio prior to preprocessing
+	rescaling_max=0.9,  # Rescaling value
+	
+	# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+	# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+	# Does not work if n_ffit is not multiple of hop_size!!
+	use_lws=False,
+	
+	n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
+	hop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+	win_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+	sample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)
+	
+	frame_shift_ms=None,  # Can replace hop_size parameter. (Recommended: 12.5)
+	
+	# Mel and Linear spectrograms normalization/scaling and clipping
+	signal_normalization=True,
+	# Whether to normalize mel spectrograms to some predefined range (following below parameters)
+	allow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True
+	symmetric_mels=True,
+	# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 
+	# faster and cleaner convergence)
+	max_abs_value=4.,
+	# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 
+	# be too big to avoid gradient explosion, 
+	# not too small for fast convergence)
+	# Contribution by @begeekmyfriend
+	# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 
+	# levels. Also allows for better G&L phase reconstruction)
+	preemphasize=True,  # whether to apply filter
+	preemphasis=0.97,  # filter coefficient.
+	
+	# Limits
+	min_level_db=-100,
+	ref_level_db=20,
+	fmin=55,
+	# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 
+	# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+	fmax=7600,  # To be increased/reduced depending on data.
+
+	###################### Our training parameters #################################
+	img_size=96,
+	fps=25,
+	
+	batch_size=16,
+	initial_learning_rate=1e-4,
+	nepochs=300000,  ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+	num_workers=20,
+	checkpoint_interval=3000,
+	eval_interval=3000,
+	writer_interval=300,
+    save_optimizer_state=True,
+
+    syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 
+	syncnet_batch_size=64,
+	syncnet_lr=1e-4,
+	syncnet_eval_interval=1000,
+	syncnet_checkpoint_interval=10000,
+
+	disc_wt=0.07,
+	disc_initial_learning_rate=1e-4,
+)
+
+
+
+# Default hyperparameters
+hparamsdebug = HParams(
+	num_mels=80,  # Number of mel-spectrogram channels and local conditioning dimensionality
+	#  network
+	rescale=True,  # Whether to rescale audio prior to preprocessing
+	rescaling_max=0.9,  # Rescaling value
+	
+	# Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
+	# It"s preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
+	# Does not work if n_ffit is not multiple of hop_size!!
+	use_lws=False,
+	
+	n_fft=800,  # Extra window size is filled with 0 paddings to match this parameter
+	hop_size=200,  # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
+	win_size=800,  # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
+	sample_rate=16000,  # 16000Hz (corresponding to librispeech) (sox --i <filename>)
+	
+	frame_shift_ms=None,  # Can replace hop_size parameter. (Recommended: 12.5)
+	
+	# Mel and Linear spectrograms normalization/scaling and clipping
+	signal_normalization=True,
+	# Whether to normalize mel spectrograms to some predefined range (following below parameters)
+	allow_clipping_in_normalization=True,  # Only relevant if mel_normalization = True
+	symmetric_mels=True,
+	# Whether to scale the data to be symmetric around 0. (Also multiplies the output range by 2, 
+	# faster and cleaner convergence)
+	max_abs_value=4.,
+	# max absolute value of data. If symmetric, data will be [-max, max] else [0, max] (Must not 
+	# be too big to avoid gradient explosion, 
+	# not too small for fast convergence)
+	# Contribution by @begeekmyfriend
+	# Spectrogram Pre-Emphasis (Lfilter: Reduce spectrogram noise and helps model certitude 
+	# levels. Also allows for better G&L phase reconstruction)
+	preemphasize=True,  # whether to apply filter
+	preemphasis=0.97,  # filter coefficient.
+	
+	# Limits
+	min_level_db=-100,
+	ref_level_db=20,
+	fmin=55,
+	# Set this to 55 if your speaker is male! if female, 95 should help taking off noise. (To 
+	# test depending on dataset. Pitch info: male~[65, 260], female~[100, 525])
+	fmax=7600,  # To be increased/reduced depending on data.
+
+	###################### Our training parameters #################################
+	img_size=96,
+	fps=25,
+	
+	batch_size=2,
+	initial_learning_rate=1e-3,
+	nepochs=100000,  ### ctrl + c, stop whenever eval loss is consistently greater than train loss for ~10 epochs
+	num_workers=0,
+	checkpoint_interval=10000,
+	eval_interval=10,
+	writer_interval=5,
+    save_optimizer_state=True,
+
+    syncnet_wt=0.0, # is initially zero, will be set automatically to 0.03 later. Leads to faster convergence. 
+	syncnet_batch_size=64,
+	syncnet_lr=1e-4,
+	syncnet_eval_interval=10000,
+	syncnet_checkpoint_interval=10000,
+
+	disc_wt=0.07,
+	disc_initial_learning_rate=1e-4,
+)
+
+
+def hparams_debug_string():
+	values = hparams.values()
+	hp = ["  %s: %s" % (name, values[name]) for name in sorted(values) if name != "sentences"]
+	return "Hyperparameters:\n" + "\n".join(hp)
diff --git a/src/utils/paste_pic.py b/src/utils/paste_pic.py
new file mode 100644
index 0000000000000000000000000000000000000000..96952c0a65360399182802d89005305f67d3eb85
--- /dev/null
+++ b/src/utils/paste_pic.py
@@ -0,0 +1,66 @@
+import cv2, os
+import numpy as np
+from tqdm import tqdm
+import uuid
+
+from src.utils.videoio import save_video_with_watermark 
+
+def paste_pic(video_path, pic_path, crop_info, new_audio_path, full_video_path):
+
+    if not os.path.isfile(pic_path):
+        raise ValueError('pic_path must be a valid path to video/image file')
+    elif pic_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+        # loader for first frame
+        full_img = cv2.imread(pic_path)
+    else:
+        # loader for videos
+        video_stream = cv2.VideoCapture(pic_path)
+        fps = video_stream.get(cv2.CAP_PROP_FPS)
+        full_frames = [] 
+        while 1:
+            still_reading, frame = video_stream.read()
+            if not still_reading:
+                video_stream.release()
+                break 
+            break 
+        full_img = frame
+    frame_h = full_img.shape[0]
+    frame_w = full_img.shape[1]
+
+    video_stream = cv2.VideoCapture(video_path)
+    fps = video_stream.get(cv2.CAP_PROP_FPS)
+    crop_frames = []
+    while 1:
+        still_reading, frame = video_stream.read()
+        if not still_reading:
+            video_stream.release()
+            break
+        crop_frames.append(frame)
+    
+    if len(crop_info) != 3:
+        print("you didn't crop the image")
+        return
+    else:
+        r_w, r_h = crop_info[0]
+        clx, cly, crx, cry = crop_info[1]
+        lx, ly, rx, ry = crop_info[2]
+        lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+        # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+        # oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+        oy1, oy2, ox1, ox2 = cly, cry, clx, crx
+
+
+    tmp_path = str(uuid.uuid4())+'.mp4'
+    out_tmp = cv2.VideoWriter(tmp_path, cv2.VideoWriter_fourcc(*'MP4V'), fps, (frame_w, frame_h))
+    for crop_frame in tqdm(crop_frames, 'seamlessClone:'):
+        p = cv2.resize(crop_frame.astype(np.uint8), (crx-clx, cry - cly)) 
+
+        mask = 255*np.ones(p.shape, p.dtype)
+        location = ((ox1+ox2) // 2, (oy1+oy2) // 2)
+        gen_img = cv2.seamlessClone(p, full_img, mask, location, cv2.NORMAL_CLONE)
+        out_tmp.write(gen_img)
+
+    out_tmp.release()
+
+    save_video_with_watermark(tmp_path, new_audio_path, full_video_path, watermark=False)
+    os.remove(tmp_path)
diff --git a/src/utils/preprocess.py b/src/utils/preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea8e5c2c27369b754bc5c2defb12290cfedac968
--- /dev/null
+++ b/src/utils/preprocess.py
@@ -0,0 +1,168 @@
+import numpy as np
+import cv2, os, sys, torch
+from tqdm import tqdm
+from PIL import Image 
+
+# 3dmm extraction
+from src.face3d.util.preprocess import align_img
+from src.face3d.util.load_mats import load_lm3d
+from src.face3d.models import networks
+
+try:
+    import webui
+    from src.face3d.extract_kp_videos_safe import KeypointExtractor
+    assert torch.cuda.is_available() == True
+except:
+    from src.face3d.extract_kp_videos import KeypointExtractor
+
+from scipy.io import loadmat, savemat
+from src.utils.croper import Croper
+
+import warnings 
+warnings.filterwarnings("ignore")
+
+def split_coeff(coeffs):
+        """
+        Return:
+            coeffs_dict     -- a dict of torch.tensors
+
+        Parameters:
+            coeffs          -- torch.tensor, size (B, 256)
+        """
+        id_coeffs = coeffs[:, :80]
+        exp_coeffs = coeffs[:, 80: 144]
+        tex_coeffs = coeffs[:, 144: 224]
+        angles = coeffs[:, 224: 227]
+        gammas = coeffs[:, 227: 254]
+        translations = coeffs[:, 254:]
+        return {
+            'id': id_coeffs,
+            'exp': exp_coeffs,
+            'tex': tex_coeffs,
+            'angle': angles,
+            'gamma': gammas,
+            'trans': translations
+        }
+
+
+class CropAndExtract():
+    def __init__(self, path_of_lm_croper, path_of_net_recon_model, dir_of_BFM_fitting, device):
+
+        self.croper = Croper(path_of_lm_croper)
+        self.kp_extractor = KeypointExtractor(device)
+        self.net_recon = networks.define_net_recon(net_recon='resnet50', use_last_fc=False, init_path='').to(device)
+        checkpoint = torch.load(path_of_net_recon_model, map_location=torch.device(device))    
+        self.net_recon.load_state_dict(checkpoint['net_recon'])
+        self.net_recon.eval()
+        self.lm3d_std = load_lm3d(dir_of_BFM_fitting)
+        self.device = device
+    
+    def generate(self, input_path, save_dir, crop_or_resize='crop', source_image_flag=False):
+
+        pic_size = 256
+        pic_name = os.path.splitext(os.path.split(input_path)[-1])[0]  
+
+        landmarks_path =  os.path.join(save_dir, pic_name+'_landmarks.txt') 
+        coeff_path =  os.path.join(save_dir, pic_name+'.mat')  
+        png_path =  os.path.join(save_dir, pic_name+'.png')  
+
+        #load input
+        if not os.path.isfile(input_path):
+            raise ValueError('input_path must be a valid path to video/image file')
+        elif input_path.split('.')[-1] in ['jpg', 'png', 'jpeg']:
+            # loader for first frame
+            full_frames = [cv2.imread(input_path)]
+            fps = 25
+        else:
+            # loader for videos
+            video_stream = cv2.VideoCapture(input_path)
+            fps = video_stream.get(cv2.CAP_PROP_FPS)
+            full_frames = [] 
+            while 1:
+                still_reading, frame = video_stream.read()
+                if not still_reading:
+                    video_stream.release()
+                    break 
+                full_frames.append(frame) 
+                if source_image_flag:
+                    break
+
+        x_full_frames= [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)  for frame in full_frames] 
+
+        #### crop images as the 
+        if crop_or_resize.lower() == 'crop': # default crop
+            x_full_frames, crop, quad = self.croper.crop(x_full_frames, still=True, xsize=512)
+            clx, cly, crx, cry = crop
+            lx, ly, rx, ry = quad
+            lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+            oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+            crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+        elif crop_or_resize.lower() == 'full':
+            x_full_frames, crop, quad = self.croper.crop(x_full_frames, still=True, xsize=512)
+            clx, cly, crx, cry = crop
+            lx, ly, rx, ry = quad
+            lx, ly, rx, ry = int(lx), int(ly), int(rx), int(ry)
+            oy1, oy2, ox1, ox2 = cly+ly, cly+ry, clx+lx, clx+rx
+            crop_info = ((ox2 - ox1, oy2 - oy1), crop, quad)
+        else: # resize mode
+            oy1, oy2, ox1, ox2 = 0, x_full_frames[0].shape[0], 0, x_full_frames[0].shape[1] 
+            crop_info = ((ox2 - ox1, oy2 - oy1), None, None)
+
+        frames_pil = [Image.fromarray(cv2.resize(frame,(pic_size, pic_size))) for frame in x_full_frames]
+        if len(frames_pil) == 0:
+            print('No face is detected in the input file')
+            return None, None
+
+        # save crop info
+        for frame in frames_pil:
+            cv2.imwrite(png_path, cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
+
+        # 2. get the landmark according to the detected face. 
+        if not os.path.isfile(landmarks_path): 
+            lm = self.kp_extractor.extract_keypoint(frames_pil, landmarks_path)
+        else:
+            print(' Using saved landmarks.')
+            lm = np.loadtxt(landmarks_path).astype(np.float32)
+            lm = lm.reshape([len(x_full_frames), -1, 2])
+
+        if not os.path.isfile(coeff_path):
+            # load 3dmm paramter generator from Deep3DFaceRecon_pytorch 
+            video_coeffs, full_coeffs = [],  []
+            for idx in tqdm(range(len(frames_pil)), desc='3DMM Extraction In Video:'):
+                frame = frames_pil[idx]
+                W,H = frame.size
+                lm1 = lm[idx].reshape([-1, 2])
+            
+                if np.mean(lm1) == -1:
+                    lm1 = (self.lm3d_std[:, :2]+1)/2.
+                    lm1 = np.concatenate(
+                        [lm1[:, :1]*W, lm1[:, 1:2]*H], 1
+                    )
+                else:
+                    lm1[:, -1] = H - 1 - lm1[:, -1]
+
+                trans_params, im1, lm1, _ = align_img(frame, lm1, self.lm3d_std)
+ 
+                trans_params = np.array([float(item) for item in np.hsplit(trans_params, 5)]).astype(np.float32)
+                im_t = torch.tensor(np.array(im1)/255., dtype=torch.float32).permute(2, 0, 1).to(self.device).unsqueeze(0)
+                
+                with torch.no_grad():
+                    full_coeff = self.net_recon(im_t)
+                    coeffs = split_coeff(full_coeff)
+
+                pred_coeff = {key:coeffs[key].cpu().numpy() for key in coeffs}
+ 
+                pred_coeff = np.concatenate([
+                    pred_coeff['exp'], 
+                    pred_coeff['angle'],
+                    pred_coeff['trans'],
+                    trans_params[2:][None],
+                    ], 1)
+                video_coeffs.append(pred_coeff)
+                full_coeffs.append(full_coeff.cpu().numpy())
+
+            semantic_npy = np.array(video_coeffs)[:,0] 
+
+            savemat(coeff_path, {'coeff_3dmm': semantic_npy, 'full_3dmm': np.array(full_coeffs)[0]})
+
+        return coeff_path, png_path, crop_info
diff --git a/src/utils/text2speech.py b/src/utils/text2speech.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dba71bbb5ae1dbe6616d06e3c63bec9974041e0
--- /dev/null
+++ b/src/utils/text2speech.py
@@ -0,0 +1,21 @@
+import os
+import tempfile
+from TTS.api import TTS
+
+
+
+class TTSTalker():
+    def __init__(self) -> None:
+        model_name = TTS.list_models()[0]
+        self.tts = TTS(model_name)
+
+    def test(self, text, language='en'):
+
+        tempf  = tempfile.NamedTemporaryFile(
+                delete = False,
+                suffix = ('.'+'wav'),
+            )
+
+        self.tts.tts_to_file(text, speaker=self.tts.speakers[0], language=language, file_path=tempf.name)
+
+        return tempf.name
\ No newline at end of file
diff --git a/src/utils/videoio.py b/src/utils/videoio.py
new file mode 100644
index 0000000000000000000000000000000000000000..369fea84bd2c8367d6d0e3b3d6bc165fcb86b1c5
--- /dev/null
+++ b/src/utils/videoio.py
@@ -0,0 +1,79 @@
+import shutil
+import uuid
+
+import os
+
+import cv2
+
+def load_video_to_cv2(input_path):
+    video_stream = cv2.VideoCapture(input_path)
+    fps = video_stream.get(cv2.CAP_PROP_FPS)
+    full_frames = [] 
+    while 1:
+        still_reading, frame = video_stream.read()
+        if not still_reading:
+            video_stream.release()
+            break 
+        full_frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+    return full_frames
+
+def save_video_with_watermark(video, audio, save_path, watermark=False):
+    temp_file = str(uuid.uuid4())+'.mp4'
+    from moviepy.editor import VideoFileClip, AudioFileClip 
+    import tempfile
+    import base64
+
+    # Load audio and video files
+    audio_clip = AudioFileClip(audio) 
+    video_clip = VideoFileClip(video) 
+    # Combine audio and video
+    final_clip = video_clip.set_audio(audio_clip)
+
+    temp_file = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+    temp_file.close()  # Close the file so it can be opened by MoviePy
+    temp_file_path = temp_file.name
+    final_clip.write_videofile(temp_file_path)
+
+    with open(temp_file_path, 'rb') as f:
+        video_content = f.read()
+
+    base64_video = base64.b64encode(video_content).decode('utf-8')
+
+    return base64_video, temp_file_path
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+    # cmd = r'ffmpeg -y -i "%s" -i "%s" -vcodec copy "%s"' % (video, audio, temp_file)
+    # os.system(cmd)
+
+    # if watermark is False:
+    #     shutil.move(temp_file, save_path)
+    # else:
+    #     # watermark
+    #     try:
+    #         ##### check if stable-diffusion-webui
+    #         import webui
+    #         from modules import paths
+    #         watarmark_path = paths.script_path+"/extensions/SadTalker/docs/sadtalker_logo.png"
+    #     except:
+    #         # get the root path of sadtalker.
+    #         dir_path = os.path.dirname(os.path.realpath(__file__))
+    #         watarmark_path = dir_path+"/../../docs/sadtalker_logo.png"
+
+    #     cmd = r'ffmpeg -y -hide_banner -i "%s" -i "%s" -filter_complex "[1]scale=100:-1[wm];[0][wm]overlay=(main_w-overlay_w)-10:10" "%s"' % (temp_file, watarmark_path, save_path)
+    #     os.system(cmd)
+    #     os.remove(temp_file)
\ No newline at end of file